<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.9.3">Jekyll</generator><link href="https://johanwind.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://johanwind.github.io/" rel="alternate" type="text/html" /><updated>2023-05-08T19:32:53+00:00</updated><id>https://johanwind.github.io/feed.xml</id><title type="html">The Good Minima</title><subtitle>A blog about implicit biases in deep learning.</subtitle><entry><title type="html">How the RWKV language model works</title><link href="https://johanwind.github.io/2023/03/23/rwkv_details.html" rel="alternate" type="text/html" title="How the RWKV language model works" /><published>2023-03-23T00:00:00+00:00</published><updated>2023-03-23T00:00:00+00:00</updated><id>https://johanwind.github.io/2023/03/23/rwkv_details</id><content type="html" xml:base="https://johanwind.github.io/2023/03/23/rwkv_details.html">&lt;p&gt;In this post, I will explain the details of how RWKV generates text. For a high level overview of what RWKV is and what is so special about it, check out &lt;a href=&quot;/2023/03/23/rwkv_overview.html&quot;&gt;the other post about RWKV&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;To explain exactly how RWKV works, I think it is easiest to look at a simple implementation of it. The following ~100 line code (based on &lt;a href=&quot;https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py&quot;&gt;RWKV in 150 lines&lt;/a&gt;) is a minimal implementation of a relatively small (430m parameter) RWKV model which generates text.&lt;/p&gt;

&lt;p&gt;&lt;details&gt; &lt;summary&gt;Minimal RWKV code&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;load&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch_load&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Only for loading the model weights
&lt;/span&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tokenizers&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Tokenizer&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;layer_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;std&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;time_mixing&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_num&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_den&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;decay&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bonus&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;wkv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;last_num&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bonus&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;      \
          &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;last_den&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bonus&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;rwkv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wkv&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;num&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;decay&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_num&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;den&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;decay&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_den&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wout&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rwkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;den&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;channel_mixing&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;vk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;maximum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;vk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;RWKV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;prefix&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;keys&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;startswith&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prefix&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'emb'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layer_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'blocks.0.ln0'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;N_LAYER&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layer_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'blocks.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.ln1'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;dx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][:&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time_mixing&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][:&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'blocks.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.att'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dx&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;x_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layer_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'blocks.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.ln2'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;dx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;channel_mixing&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'blocks.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.ffn'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dx&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layer_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'ln_out'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'head'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Softmax of x
&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;##########################################################################################################
&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;sample_probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;temperature&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;top_p&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.85&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sorted_probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[::&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cumulative_probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cumsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sorted_probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cutoff&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sorted_probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;argmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cumulative_probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;top_p&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cutoff&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;temperature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;choice&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# Available at https://huggingface.co/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MODEL_FILE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'/data/rwkv/RWKV-4-Pile-430M-20220808-8066.pth'&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;N_LAYER&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;24&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;N_EMBD&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Loading &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MODEL_FILE&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch_load&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MODEL_FILE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;map_location&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'cpu'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;keys&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'.time_'&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;squeeze&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;numpy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# convert to f32 type
&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Available at https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tokenizer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Tokenizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_file&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;/data/rwkv/20B_tokenizer.json&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Preprocessing context'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.&quot;&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;N_LAYER&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N_EMBD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;token&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tokenizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;encode&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;RWKV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;end&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;token&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sample_probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tokenizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;decode&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;end&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;flush&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;RWKV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;&lt;/p&gt;

&lt;p&gt;To avoid hiding complexity, the model computation itself is written entirely in python, with numpy for matrix / vector operations. However, I needed to use &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;torch.load&lt;/code&gt; to load the model weights from a file, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tokenizers.Tokenizer&lt;/code&gt; to make the text into tokens the model can work with.&lt;/p&gt;

&lt;h2 id=&quot;text-generation-with-rwkv&quot;&gt;Text generation with RWKV&lt;/h2&gt;
&lt;p&gt;The code uses RWKV to continue the following text:&lt;/p&gt;

&lt;p&gt;“In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.”&lt;/p&gt;

&lt;p&gt;We first need to convert this text into a series of tokens (numbers from 0 to 50276 representing words/symbols/tokens in our vocabulary). That is not the focus of this blog post, so I just do it with an external library &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tokenizer.encode(context).ids&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;Next, we need to process this sequence of tokens into an RWKV state. Essentially, RWKV represents a function which takes a token and a state, and outputs a probability distribution over the next token, and a new state. Of course, the function also depends on the RWKV model parameters, but since we use a trained model (downloaded from &lt;a href=&quot;https://huggingface.co/BlinkDL/rwkv-4-pile-430m&quot;&gt;here&lt;/a&gt;), we view those parameters as fixed. To convert the text to a state, we just initialize the state to all zeros, and feed the tokens through the RWKV function one by one.&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;state = np.zeros((N_LAYER, 4, N_EMBD), dtype=np.float32)
for token in tokenizer.encode(context).ids:
    probs, state = RWKV(weights, token, state)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Now the variable &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;state&lt;/code&gt; contains a state representation of our input text, and the variable “probs” contain the probability distribution the model predicts for the next token.&lt;/p&gt;

&lt;p&gt;We can now simply sample the probability distribution (in practice, we avoid low probability tokens in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;sample_probs()&lt;/code&gt;) and add another token to the text. Then we feed the new token into RWKV and repeat.&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;for i in range(100):
    token = sample_probs(probs)
    print(tokenizer.decode([token]), end=&quot;&quot;, flush=True)
    probs, state = RWKV(weights, token, state)       
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;A typical, generated continuation is:&lt;/p&gt;

&lt;p&gt;“They’re just like us. They use Tibetan for communication, and for a different reason – they use a language that they’re afraid to use. To protect their secret, they prefer to speak a different language to the local public.”&lt;/p&gt;

&lt;p&gt;Of course, larger models will perform better than this relatively small 430m RWKV.&lt;/p&gt;

&lt;h2 id=&quot;what-goes-on-inside-rwkv&quot;&gt;What goes on inside RWKV()&lt;/h2&gt;

&lt;p&gt;The first thing RWKV does is look up the embedding vector of the input token. I.e &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x = params('emb')[0][token]&lt;/code&gt;. Here &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;params('emb')[0]&lt;/code&gt; is simply a \(50277 \times 1024\) matrix, and we extract a row.&lt;/p&gt;

&lt;p&gt;The next line &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x = layer_norm(x, *params('blocks.0.ln0'))&lt;/code&gt; requires me to explain what a Layer Normalization is. The easiest way is to just show the definition:&lt;/p&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;layer_norm = lambda x, w, b : (x - np.mean(x)) / np.std(x) * w + b&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;The intuition is that it normalizes a vector x to zero mean and unit variance, and then scales and offsets that. Note that the scale &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;w&lt;/code&gt; and offset &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b&lt;/code&gt; are 1024-dimensional vectors, which are learned model parameters.&lt;/p&gt;

&lt;p&gt;Now we get to the main part of the model. Which is split into 24 layers, applied sequentially.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;N_LAYER&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layer_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'blocks.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.ln1'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;dx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][:&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time_mixing&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][:&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'blocks.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.att'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dx&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;x_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layer_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'blocks.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.ln2'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;dx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;channel_mixing&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'blocks.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.ffn'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dx&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Note that we are only adding updates to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; like &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x = x + dx&lt;/code&gt;, this is called using “residual connections”. Each time we make a copy of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt;, we feed it through a layer normalization before mixing it.  Each layer has two mixing functions: a “time mixing” part and a “channel mixing” part. In a typical transformer, the “time mixing” would be done by multi head attention, and the “channel mixing” would be done by a simple feed forward network. RWKV does something a bit different, which we’ll explain in the next sections.&lt;/p&gt;

&lt;h3 id=&quot;channel-mixing&quot;&gt;Channel mixing&lt;/h3&gt;
&lt;p&gt;I’ll start with channel mixing, since it’s the simpler one of the two mixing functions.&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;channel_mixing&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;vk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;maximum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;vk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;The channel mixing layer takes an input &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; corresponding to this token, and the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; corresponding to the previous token, which we call &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;last_x&lt;/code&gt;. &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;last_x&lt;/code&gt; was stored in this RWKV layer’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;state&lt;/code&gt;. The rest of the inputs are learned RWKV parameters.&lt;/p&gt;

&lt;p&gt;First, we linearly interpolate &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;last_x&lt;/code&gt;, using learned weights. We run this interpolated &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; as input to a 2 layer feed forward network with squared relu activation, and finally multiply with the sigmoid activations of another feed forward network (in classical RNN terms, this would be called gating).&lt;/p&gt;

&lt;p&gt;Note that in terms of memory usage, the matrices &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Wk,Wr,Wv&lt;/code&gt; hold almost all the parameters (the smallest of them is a \(1024\times 1024\) matrix, while the other variables are just 1024-dimensional vectors). And the matrix multiplications (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;@&lt;/code&gt; in python) contribute the vast majority of required computations.&lt;/p&gt;

&lt;h3 id=&quot;time-mixing&quot;&gt;Time mixing&lt;/h3&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;time_mixing&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_num&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_den&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;decay&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bonus&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mix_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;wkv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;last_num&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bonus&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;      \
          &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;last_den&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bonus&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;rwkv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wkv&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;num&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;decay&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_num&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;den&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;decay&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;last_den&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Wout&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rwkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;den&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;The time mixing starts similarly to the channel mixing, by interpolating this token’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; with the last token’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt;. We then apply learned \(1024\times 1024\) matrices to get “key”, “value” and “receptance” vectors.&lt;/p&gt;

&lt;p&gt;The next part is where the magic happens.&lt;/p&gt;

&lt;h4 id=&quot;the-rwkv-attention&quot;&gt;The “RWKV attention”&lt;/h4&gt;
&lt;p&gt;Before getting to the core of the mechanism, we will make the observation that while the variables going into the attention mechanism are all 1024-dimensional (we say they have 1024 channels), all channels are computed independently of each other. We will therefore just look at what happens to a single channel, treating the variables as scalars.&lt;/p&gt;

&lt;p&gt;Now, let us look at the variable &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num&lt;/code&gt;. To make math notations cleaner, let’s rename &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;den&lt;/code&gt; to \(\alpha\) and \(\beta\). Both \(\alpha\) and \(\beta\) are stored in the RWKV state. For each new token, \(\alpha\) is calculated as \(\alpha_i = e^{-w} \alpha_{i-1} +e^{k_i} v_i\), where \(i\) is the index of the current token. We defined &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;w = exp(decay)&lt;/code&gt;, note that &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;w&lt;/code&gt; is always positive.&lt;/p&gt;

&lt;p&gt;By induction, we have \(\alpha_i = \sum_{j=1}^i e^{-(i-j)w+k_j} v_j\). Similarly, \(\beta_i = \sum_{j=1}^i e^{-(i-j)w+k_j}\). Note that \(\alpha_i\) looks like a weighted sum of the \(v_j\), while \(\beta_i\) is just the sum of weights. So \(\frac{\alpha_i}{\beta_i}\) becomes a weighted average of \(v_j\).&lt;/p&gt;

&lt;p&gt;Plugging in the formulas for \(\alpha_{i-1}\) and \(\beta_{i-1}\) into the definition of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;wkv&lt;/code&gt;, and denoting &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;bonus&lt;/code&gt; by \(u\), we get&lt;/p&gt;

\[\text{wkv}_i = \frac{ \sum_{j=1}^{i-1} e^{-(i-1-j)w+k_j} v_j + e^{u+k_i} v_i }{\sum_{j=1}^{i-1} e^{-(i-1-j)w+k_j} + e^{u+k_i}}.\]

&lt;p&gt;So \(\text{wkv}\) is a weighted average of \(v\) with weights according to \(k\), but also the current \(v_i\) is given a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;bonus&lt;/code&gt; (\(u\)) additional weight, and previous \(v_j\) are given geometrically smaller weights the further away they are.&lt;/p&gt;

&lt;p&gt;For reference, standard transformer attention takes “query”, “key” and “value” vectors \(q,k,v\) and outputs&lt;/p&gt;

\[\frac{\sum_{j=1}^i e^{q_i^\top k_j} v_j}{\sum_{j=1}^i e^{q_i^\top k_j}}.\]

&lt;p&gt;After calculating &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;wkv&lt;/code&gt;, the time mixing multiplies by the “receptance” &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;sigmoid(r)&lt;/code&gt;. It does a final linear transformation before returning the result.&lt;/p&gt;

&lt;h3 id=&quot;converting-to-output-probabilities&quot;&gt;Converting to output probabilities&lt;/h3&gt;
&lt;p&gt;After going through the 24 layers of time mixing and channel mixing, we need to convert the final output to predicted probabilities for the next token.&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layer_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'ln_out'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'head'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Softmax of x
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;First, we do a layer normalization. Then, we multiply by a \(50277 \times 1024\) matrix &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;params('head')[0]&lt;/code&gt; given by the RWKV parameters, giving us a 50277-dimensional vector. To get a probability distribution over tokens (i.e. a 50277-dimensional, non-negative vector which sums to 1), we run our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; through a “softmax” function. The softmax of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; is just &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;exp(x)/sum(exp(x))&lt;/code&gt;. However, calculating &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;exp(x)&lt;/code&gt; can cause numerical overflows, so we calculate the equivalent function &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;exp(x-max(x))/sum(exp(x-max(x)))&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;That’s it! Now you know exactly how RWKV works for generating text.&lt;/p&gt;

&lt;h2 id=&quot;practical-considerations&quot;&gt;Practical considerations&lt;/h2&gt;
&lt;p&gt;In practice, there are some issues which I ignored in my simplified code. Most importantly, in practice, we care a lot about the performance / run-time of the code. This leads us to run RWKV in parallel on GPUs, use specialized GPU code written in CUDA, use 16-bit floating point numbers, and more.&lt;/p&gt;

&lt;h3 id=&quot;numerical-issues&quot;&gt;Numerical issues&lt;/h3&gt;
&lt;p&gt;The largest number a 16-bit floating point number (float16) can represent is 65 504, anything above that overflows, which is bad. Most of the code has no problems with this, partially because the Layer Normalizations keep values in a reasonable range. However, the RWKV attention contains exponentially large numbers (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;exp(bonus + k)&lt;/code&gt;). In practice, the RWKV attention is implemented in a way where we factor out an exponential factor from &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;den&lt;/code&gt; to keep everything within float16 range. See for example the time_mixing function in &lt;a href=&quot;https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py&quot;&gt;RWKV in 150 lines&lt;/a&gt;.&lt;/p&gt;

&lt;h3 id=&quot;training&quot;&gt;Training&lt;/h3&gt;
&lt;p&gt;We simply loaded a pretrained model in our example. To train the model, one would calculate the cross entropy loss of the predicted probabilities on a long text (our example model was trained on &lt;a href=&quot;https://pile.eleuther.ai/&quot;&gt;the pile&lt;/a&gt;). Next, calculate the gradient of that loss with respect to all the RWKV parameters. That gradient is used to improve the parameters using a variant of Gradient Descent called Adam. Repeat for a long time, and you get a trained RWKV model.&lt;/p&gt;

&lt;h3 id=&quot;gpt-mode&quot;&gt;GPT-mode&lt;/h3&gt;
&lt;p&gt;My simplified code processes the tokens one by one, which is much slower than processing them in parallel, especially when running on GPUs. For inference, there is no way around this, as we need to sample a token before we can use it to calculate the next one. However, for training, all the text is already available. This lets us parallelize across tokens. Most of the code is fairly straightforward to parallelize like this, as there is little dependence through time. For example, all the expensive matrix multiplications work on each token independently, leading to good performance.&lt;/p&gt;

&lt;p&gt;However, the RWKV attention is inherently sequential. Fortunately, it has very little computation (on the order of 1024 times less than the matrix multiplications), so it should be fast. Sadly, pytorch does not have a good way of handling this sequential task, so the attention part becomes slow (even compared to the matrix multiplications). Therefore, I wrote optimized CUDA kernels for computing the RWKV attention, which has been my main contribution to the RWKV project.&lt;/p&gt;

&lt;p&gt;JAX has jax.lax.scan and jax.lax.associative_scan, which allows a pure JAX implementation to perform better than pure pytorch. However, I still estimate that JAX would lead to &lt;a href=&quot;https://discord.com/channels/992359628979568762/992363236370436136/1000838180674736158&quot;&gt;about 40% slower training compared to CUDA&lt;/a&gt; (that estimate may be outdated, as it was made for training a relatively small 1.5B model).&lt;/p&gt;

&lt;h1 id=&quot;contribute&quot;&gt;Contribute&lt;/h1&gt;
&lt;p&gt;RWKV is an open source community project. Join the &lt;a href=&quot;https://discord.gg/neCJHNcDCZ&quot;&gt;Discord&lt;/a&gt; and contribute! Or just ask questions or lurk.&lt;/p&gt;</content><author><name></name></author><summary type="html">In this post, I will explain the details of how RWKV generates text. For a high level overview of what RWKV is and what is so special about it, check out the other post about RWKV.</summary></entry><entry><title type="html">The RWKV language model: An RNN with the advantages of a transformer</title><link href="https://johanwind.github.io/2023/03/23/rwkv_overview.html" rel="alternate" type="text/html" title="The RWKV language model: An RNN with the advantages of a transformer" /><published>2023-03-23T00:00:00+00:00</published><updated>2023-03-23T00:00:00+00:00</updated><id>https://johanwind.github.io/2023/03/23/rwkv_overview</id><content type="html" xml:base="https://johanwind.github.io/2023/03/23/rwkv_overview.html">&lt;p&gt;For a while, I’ve been following and contributing to the RWKV language model, an open source large language model with great potential. As ChatGPT and large language models in general have gotten a lot of attention recently, I think it’s a good time to write about RWKV. In this post, I will try to explain what is so special about RWKV compared to most language models (transformers). &lt;a href=&quot;/2023/03/23/rwkv_details.html&quot;&gt;The other RWKV post&lt;/a&gt; is more technical, showing in detail how RWKV actually works (with a ~100 line minimal implementation).&lt;/p&gt;

&lt;p&gt;At a high level, the RWKV model is a clever RNN architecture that enables it to be trained like a transformer. So to explain RWKV, I need to explain RNNs and transformers first.&lt;/p&gt;

&lt;h1 id=&quot;rnns&quot;&gt;RNNs&lt;/h1&gt;
&lt;p&gt;Classically, the neural networks used for sequence (such as text) processing were RNNs (like LSTMs). RNNs take two inputs: a state vector and a token&lt;sup id=&quot;fnref:token&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:token&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt;. It goes through the input sequence one token at a time, each token updating the state. We may for example use an RNN to process a text into a single state vector. This can then be used to classify the text into “positive” or “negative”. Or we may use the final state to predict the next token, which is how RNNs are used to generate text.&lt;/p&gt;

&lt;h1 id=&quot;transformers&quot;&gt;Transformers&lt;/h1&gt;
&lt;p&gt;Because of the sequential nature of RNNs, they are hard to massively parallelize across many GPUs. This motivated using an “attention” mechanism instead of sequential processing, resulting in an architecture called a transformer. A transformer processes all tokens at the same time, comparing each token to all previous tokens in parallel. Specifically, the attention calculates “key”, “value” and “query” vectors for each token, then contributions between all pairs of tokens are computed using those.&lt;/p&gt;

&lt;p&gt;In addition to being able to speed up training through massive parallelization, large transformers generally score better than RNNs on benchmarks.&lt;/p&gt;

&lt;p&gt;However, the attention mechanism scales quadratically with the length of the sequence to be processed. This effectively limits the model’s input size (or “context length”). Additionally, because of the attention mechanism, when generating text, we need to keep attention vectors for all previous tokens in memory. This requires much more memory than an RNN which only stores a single state.&lt;/p&gt;

&lt;h1 id=&quot;rwkv&quot;&gt;RWKV&lt;/h1&gt;
&lt;p&gt;RWKV combines the best features of RNNs and transformers. During training, we use the transformer type formulation of the architecture, which allows massive parallelization (with a sort of attention which scales linearly with the number of tokens). For inference, we use an equivalent formulation which works like an RNN with a state. This allows us to get the best of both worlds.&lt;/p&gt;

&lt;p&gt;So we basically have a model which trains like a transformer, except that long context length is not expensive. And during inference, we need substantially less memory and can implicitly handle “infinite” context length (though in practice, the model might have a hard time generalizing to much longer context lengths than it saw during training).&lt;/p&gt;

&lt;p&gt;OK, but what about the performance? Since RWKV an RNN, it is natural to think that it can’t perform as well as a transformer on benchmarks. Also, this just sounds like linear attention. None of the many previous linear time attention transformer architectures (like “Linformer”, “Nystromformer”, “Longformer”, “Performer”) seemed to take off.&lt;/p&gt;

&lt;h1 id=&quot;benchmarks&quot;&gt;Benchmarks&lt;/h1&gt;
&lt;p&gt;Well, RWKV seems to scale as well as SOTA transformers. At least up to 14 billion parameters.
&lt;img src=&quot;/images/rwkv_benchmark.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;h1 id=&quot;contribute&quot;&gt;Contribute&lt;/h1&gt;
&lt;p&gt;RWKV is an open source community project. Join the &lt;a href=&quot;https://discord.gg/neCJHNcDCZ&quot;&gt;Discord&lt;/a&gt; and contribute (or ask questions or whatever).&lt;/p&gt;

&lt;h1 id=&quot;cost-estimates-for-large-language-models&quot;&gt;Cost estimates for Large Language Models&lt;/h1&gt;
&lt;p&gt;When looking at RWKV 14B (14 billion parameters), it is easy to ask what happens when we scale to 175B like GPT-3. However, training a 175B model is expensive. Calculating the approximate training cost of a transformer-like architecture is actually straightforward.&lt;/p&gt;

&lt;p&gt;The bottleneck for training is essentially multiplying by all the parameters, and then adding that together, for each input token. With automatic differentiation, we can calculate the gradient with about another 2x that, for a total of 6 FLOPs per parameter per token. So a 14B model trained on 300 billion tokens takes about \(14B \times 300B \times 6 = 2.5 \times 10^{22}\) FLOPs. We use A100 GPUs for training. Using 16-bit floating point numbers, an A100 can theoretically do up to 312 TFLOPS, or about \(1.1\times 10^{18}\) FLOPs per hour. So we theoretically need at least 22 436 hours of A100 time to train. In practice, RWKV 14B was trained on 64 A100s in parallel, sacrificing a bit of performance for various reasons. RWKV 14B took about 3 months \(\approx 140\ 160\) A100 hours to train, thus achieving about 20% theoretical efficiency (since it took about 5x longer than the theoretical minimum). Recent versions can train RWKV 14B at around 50% theoretical efficency.&lt;/p&gt;

&lt;p&gt;As a rough price estimate, at the time of writing, the cheapest A100 cost at cloud-gpus.com was $0.79/h. Training the original 14B RWKV there would hence cost around $100k, but with the recent training code improvements we could reduce this to $40k. In practice, there are other considerations like ease of use, timeouts, multi-gpu communication speed, etc. Thus, one might want more high-end options like AWS at $4.096/h. RWKV was trained on compute donated by Stability and EleutherAI.&lt;/p&gt;

&lt;p&gt;Now you can imagine that training 10x more parameters and 10x more data will cost 100x more, making it prohibitively expensive.&lt;/p&gt;

&lt;h4 id=&quot;footnote&quot;&gt;Footnote&lt;/h4&gt;
&lt;div class=&quot;footnotes&quot; role=&quot;doc-endnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:token&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Before using a language model on a text, we &lt;em&gt;tokenize&lt;/em&gt; the text into tokens. Intuitively speaking, a token is basically a word. In practice, a tokenizer might split words into multiple tokens, has to handle special characters and punctuation, and employ some tricks like adding a token for “end of text”. The &lt;a href=&quot;https://huggingface.co/BlinkDL/rwkv-4-pile-14b&quot;&gt;14B RWKV model&lt;/a&gt; uses 50277 different tokens. &lt;a href=&quot;#fnref:token&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;</content><author><name></name></author><summary type="html">For a while, I’ve been following and contributing to the RWKV language model, an open source large language model with great potential. As ChatGPT and large language models in general have gotten a lot of attention recently, I think it’s a good time to write about RWKV. In this post, I will try to explain what is so special about RWKV compared to most language models (transformers). The other RWKV post is more technical, showing in detail how RWKV actually works (with a ~100 line minimal implementation).</summary></entry><entry><title type="html">94% on CIFAR-10 in 94 lines and 94 seconds</title><link href="https://johanwind.github.io/2022/12/28/cifar_94.html" rel="alternate" type="text/html" title="94% on CIFAR-10 in 94 lines and 94 seconds" /><published>2022-12-28T00:00:00+00:00</published><updated>2022-12-28T00:00:00+00:00</updated><id>https://johanwind.github.io/2022/12/28/cifar_94</id><content type="html" xml:base="https://johanwind.github.io/2022/12/28/cifar_94.html">&lt;p&gt;The following code scores 94.02% test accuracy (mean of 40 runs) in 94 lines and less than 94 seconds training time (loading / validation / etc. not included). My tests were performed on NVIDIA A10 GPUs, with pytorch version 1.12.1+cu116 and torchvision version 0.13.1+cu116.&lt;/p&gt;
&lt;details&gt; &lt;summary&gt;Code&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torchvision&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sys&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;numpy&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;cuda&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float16&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;EPOCHS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;24&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;BATCH_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;512&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;MOMENTUM&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.9&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;WEIGHT_DECAY&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;5e-4&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BATCH_SIZE&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;lr_knots&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;EPOCHS&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;EPOCHS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;lr_vals&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.6&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;H&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;CUTSIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;CROPSIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;CNN&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;dims&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;256&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;512&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;512&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;512&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;c_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;in_channels&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_channels&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_in&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MaxPool2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BatchNorm2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CELU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;alpha&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.075&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Sequential&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MaxPool2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Flatten&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

  &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cross_entropy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;reduction&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'none'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;label_smoothing&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;argmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;loadCIFAR10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;train&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torchvision&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CIFAR10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;root&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;./data&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;download&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;test&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torchvision&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CIFAR10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;root&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;./data&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;download&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;std&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;std_mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;unbiased&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;keepdim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;std&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;permute&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;getBatches&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;istrain&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;istrain&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;perm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randperm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;perm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;perm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;Crop&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;([(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CROPSIZE&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CROPSIZE&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)],&lt;/span&gt; 
        &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ReflectionPad2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CROPSIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[...,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;FlipLR&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;([(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,)],&lt;/span&gt; 
        &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;choice&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;choice&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;cutout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[...,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CUTSIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CUTSIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;Cutout&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;([(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x0&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CUTSIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y0&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CUTSIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cutout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transform&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Crop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;FlipLR&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Cutout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;optioni&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optioni&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;==&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transform&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optioni&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;==&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;istrain&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loadCIFAR10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;CNN&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SGD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;momentum&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MOMENTUM&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weight_decay&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;WEIGHT_DECAY&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nesterov&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;training_time&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stepi&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;EPOCHS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;start_time&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;perf_counter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;train_losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_accs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;getBatches&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stepi&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_groups&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'lr'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;numpy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;interp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stepi&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;//&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr_knots&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr_vals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;acc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zero_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;train_losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;detach&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;train_accs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;acc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;detach&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;training_time&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;perf_counter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;start_time&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;eval&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;no_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;test_accs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;detach&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;getBatches&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;summary&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'epoch % 2d  train loss %.3f  train acc %.2f  test acc %.2f  training time %.2f'&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;summary&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_losses&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;summary&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_accs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;summary&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;test_accs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;training_time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;Why 94%? Because that’s reportedly &lt;a href=&quot;http://karpathy.github.io/2011/04/27/manually-classifying-cifar10/&quot;&gt;human level accuracy on CIFAR-10&lt;/a&gt;. Additionally, it made for a reasonable balance between accuracy, code complexity and training time.&lt;/p&gt;

&lt;p&gt;The code and idea is based on &lt;a href=&quot;https://colab.research.google.com/github/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb&quot;&gt;the final code&lt;/a&gt; from the fantastic blog series &lt;a href=&quot;https://myrtle.ai/learn/how-to-train-your-resnet/&quot;&gt;How to Train Your ResNet&lt;/a&gt; from &lt;a href=&quot;https://myrtle.ai/&quot;&gt;myrtle.ai&lt;/a&gt;. The blog series worked on optimizing the training time to reach 94% on CIFAR-10, and was able to get it down to 26 seconds!&lt;/p&gt;

&lt;p&gt;The code from myrtle.ai ended up being more than 500 lines long and containing a large number of clever tricks and optimizations.&lt;/p&gt;

&lt;h1 id=&quot;motivation&quot;&gt;Motivation&lt;/h1&gt;
&lt;p&gt;My main interest is in studying why deep learning works so well, specifically the &lt;a href=&quot;/2022/08/29/intro_implicit_bias.html&quot;&gt;implicit biases&lt;/a&gt; of deep learning. Ideally, I would like to study and do experiments on realistic and representative models and tasks. However, it’s not actually clear what is realistic and representative of deep learning. The best criterion seems to be “it works much better than all approaches not based on deep learning”.&lt;/p&gt;

&lt;p&gt;However, the best performing deep learning models are also the largest (taking days to train on hardware I don’t have access to) and full of tricks which make them impossible to analyze mathematically and hard to even reproduce. I am therefore interested in finding the simplest cases where deep learning works significantly better than alternative approaches. Image classification on CIFAR-10 is among the best I’ve found for these criteria.&lt;/p&gt;

&lt;p&gt;myrtle.ai’s model was a good starting point, as it reduced training time from hours to seconds, when compared to most alternatives. However, many of the tricks they used would complicates analysis while only providing a small gain in performance. I removed many of the less significant optimizations in favor of simplicity. Things like “frozen batch norm scales” and “exponential moving averages” were removed.&lt;/p&gt;

&lt;p&gt;You might think that reducing training time to seconds is overkill, since taking some minutes to train a model sounds fine. However, it is very useful to have some wiggle room in training time. As an example, the final accuracy varies from run to run because of randomness in the training procedure (data augmentation, initialization), so myrtle.ai typically did 50 runs just to reduce this variance.&lt;/p&gt;

&lt;h1 id=&quot;the-important-tricks&quot;&gt;The important tricks&lt;/h1&gt;
&lt;p&gt;While trying to simplify the code, I found the following tricks to be critical for performance.&lt;/p&gt;

&lt;h4 id=&quot;batch-normalization&quot;&gt;Batch normalization&lt;/h4&gt;
&lt;p&gt;Basically, everything diverges if you remove normalization layers in modern deep learning models. Without normalization layers, things like initialization size, learning rate and layer scaling become fragile hyperparameters. In general, it seems that even when all these (new) hyperparameters are tuned well, it’s too hard to compete with batch norm. The myrtle.ai blog &lt;a href=&quot;https://myrtle.ai/learn/how-to-train-your-resnet-7-batch-norm/&quot;&gt;came to a similar conclusion&lt;/a&gt;.&lt;/p&gt;

&lt;h4 id=&quot;data-augmentation&quot;&gt;Data augmentation&lt;/h4&gt;
&lt;p&gt;Without data augmentations, accuracies drop to around 90%, which is not significantly better non-deep learning approaches.&lt;/p&gt;

&lt;p&gt;I believe the current best non-deep learning methods for CIFAR-10 are convolutional kernel methods, which are typically based on linearizing neural network architectures. These can reach around 90% accuracy on CIFAR-10 (&lt;a href=&quot;https://arxiv.org/abs/1911.00809&quot;&gt;89%&lt;/a&gt;, &lt;a href=&quot;https://arxiv.org/abs/2003.02237&quot;&gt;90%&lt;/a&gt;), but struggle for large training sets. They need an \(N \times N\) Gram matrix, where \(N\) is the number of training samples. This means they cannot use data augmentation to the same degree as modern deep learning.&lt;/p&gt;

&lt;h4 id=&quot;learning-rate-schedule--warm-up&quot;&gt;Learning rate schedule / warm-up&lt;/h4&gt;
&lt;p&gt;I used this learning rate schedule:
&lt;img src=&quot;/images/cifar94_lr.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;h1 id=&quot;interesting-observations&quot;&gt;Interesting observations&lt;/h1&gt;

&lt;h4 id=&quot;residual-connections-were-not-necessary&quot;&gt;Residual connections were not necessary&lt;/h4&gt;
&lt;p&gt;I expected residual connections to be crucial for the network architecture. However, they were not necessary, so I removed them.&lt;/p&gt;

&lt;h4 id=&quot;output-scale-is-important&quot;&gt;Output scale is important&lt;/h4&gt;
&lt;p&gt;The code from myrtle.ai scales down the output of their model by a factor 8. This is very interesting to me, since this is a trick to increase the effect of what I call “implicit bias by small initialization”. I used the trick myself in previous blogs like &lt;a href=&quot;/2022/10/25/single_epoch.html&quot;&gt;1&lt;/a&gt; and &lt;a href=&quot;/2022/07/06/dln_classifier.html&quot;&gt;2&lt;/a&gt;.&lt;/p&gt;</content><author><name></name></author><summary type="html">The following code scores 94.02% test accuracy (mean of 40 runs) in 94 lines and less than 94 seconds training time (loading / validation / etc. not included). My tests were performed on NVIDIA A10 GPUs, with pytorch version 1.12.1+cu116 and torchvision version 0.13.1+cu116. Code import torch, torchvision, sys, time, numpy from torch import nn import torch.nn.functional as F</summary></entry><entry><title type="html">Implicit bias in single epoch SGD</title><link href="https://johanwind.github.io/2022/10/25/single_epoch.html" rel="alternate" type="text/html" title="Implicit bias in single epoch SGD" /><published>2022-10-25T00:00:00+00:00</published><updated>2022-10-25T00:00:00+00:00</updated><id>https://johanwind.github.io/2022/10/25/single_epoch</id><content type="html" xml:base="https://johanwind.github.io/2022/10/25/single_epoch.html">&lt;p&gt;Large transformer models are typically trained using only a single epoch, i.e the model only sees each data point once. Theoretically, this training regime circumvents problems related to overfitting and generalization by converting them to a question of how fast we can converge given noisy gradients.&lt;/p&gt;

&lt;p&gt;My recent experiments indicate that the implicit bias by small initialization is important in this regime, also in quite realistic settings. This makes me hopeful that the single epoch training SGD regime is a good setting to study several important phenomena observed in deep learning. In this blog I will showcase single epoch training on the following matrix sensing task.&lt;/p&gt;

&lt;h2 id=&quot;matrix-sensing-task&quot;&gt;Matrix sensing task&lt;/h2&gt;
&lt;p&gt;Our task is to estimate a \(d\times d\) ground truth matrix \(T\). The effect we are studying in this blog is more apparent at larger scales, so we pick the fairly large dimension \(d \in \{256,1024\}\). We generate \(T\) as a random matrix of rank \(r \in \{2,5\}\).&lt;/p&gt;

&lt;p&gt;The \(i\)th training example consists of a \(d\)-dimensional input vector \(a_i\), a \(d\)-dimensional vector \(b_i\) of output weights, and a scalar target value \(y_i = a_i^\top T b_i\).&lt;/p&gt;

&lt;p&gt;We let \(a_i\) and \(b_i\) be random Gaussians \(a_i, b_i \sim \mathcal{N}(0,I)\). To make \(T\) of rank \(r\), we generate it as the product of two random Gaussian matrices of sizes \(d\times r\) and \(r \times d\).&lt;/p&gt;

&lt;h2 id=&quot;simple-setting&quot;&gt;Simple setting&lt;/h2&gt;
&lt;p&gt;In this section, I will describe one of the simplest settings I know where the implicit bias is apparent. Then, in later sections I will show more realistic settings.&lt;/p&gt;

&lt;p&gt;Our simple model will be a 2 layer linear network (as in &lt;a href=&quot;/2022/07/06/dln_classifier.html&quot;&gt;this previous post&lt;/a&gt;). The parameters are two \(d\times d\) matrices \(W_1\) and \(W_2\). The model outputs \(\mathrm{predict}(a_i,b_i) = a_i^\top W_1 W_2 b_i\). We use the squared sample loss \(L_i = \frac{1}{2}(y_i-\mathrm{predict}(a_i,b_i))^2\). The parameters are initialized using Xavier normal initialization (random Gaussians with scale \(\frac{1}{\sqrt d}\)).&lt;/p&gt;

&lt;p&gt;We train the model by SGD with constant learning rate \(\eta\), for &lt;em&gt;a single epoch&lt;/em&gt;. This means we loop through the samples one by one, for each taking a step of length \(\eta\) along the negative gradient of \(L_i\). Note that we need a large enough learning rate \(\eta\) to be able to converge after only one epoch. A too small learning rate will keep the weights too close to the initialization, making learning impossible. However, a too large learning rate will make the model diverge.&lt;/p&gt;

&lt;p&gt;We will be exploiting &lt;a href=&quot;/2022/07/06/dln_classifier.html&quot;&gt;implicit bias by small initialization&lt;/a&gt;. Actually, instead of scaling down the initialization, we will scale up the target values, which is equivalent (for homogeneous models like ours). Therefore, we add a parameter \(\gamma\) which sets the scale of the ground truth matrix \(\gamma T\).&lt;/p&gt;

&lt;p&gt;We pick \(d = 256\), \(r = 2\), \(\eta = 2\times 10^{-8}\) and \(\gamma = 100\).&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code for simple setting&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;manual_seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;256&lt;/span&gt;      &lt;span class=&quot;c1&quot;&gt;# Input and output dimension
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# Rank of ground truth matrix
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;2e-8&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# Learning rate
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;yscale&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Scale factor for ground truth
&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;yscale&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;W1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;W2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;log&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;//&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W2&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;no_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Gradient descent
&lt;/span&gt;      &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;error&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W2&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Relative error
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;error&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;show&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;&lt;img src=&quot;/images/simple_online_bias.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Without exploiting the rank constraint, and using only \(\frac{1}{2}d^2\) samples, we would expect a relative reconstruction error (Frobenius norm of \(W_1W_2-T\) over Frobenius norm of \(T\)) of around \(0.5\). However, we only ran the optimization long enough to see \(\frac{1}{2}d^2\) samples, and as seen in the plot above we already converged and reconstructed the correct matrix!&lt;/p&gt;

&lt;h2 id=&quot;more-realistic-setting&quot;&gt;More realistic setting&lt;/h2&gt;
&lt;p&gt;What interests me about the setting above, i.e single epoch SGD with low rank linear ground truth, is that we can observe the same phenomena in much more realistic settings. We can add more layers, non-linearities, even normalization layers, and we still see the same phenomena!&lt;/p&gt;

&lt;p&gt;The first “realistic” setup is what I call “the CNN setup”. Typical CNNs (like ResNet and PyramidNet) use ReLU(-like) activation functions, batch normalization, and are optimized by SGD with momentum. Our model consists of 10 layers with ReLU non-linearities and Batch Normalization, optimized by SGD with momentum.&lt;/p&gt;

&lt;p&gt;The second “realistic” setup, I call “the transformer setup”. Transformers typically use ReLU-like activations, layer normalization, and are optimized by the Adam optimizer. We stack 10 layers with ReLU activations and layer normalization, and optimize by Adam with default parameters.&lt;/p&gt;

&lt;p&gt;We also scale up the experiment a bit to \(d = 1024\), \(r = 5\) and use mini-batches of size \(256\) to speed up the computations.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code for more realistic setting&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tqdm&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tqdm&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Optional loading bar
&lt;/span&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;F&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;functional&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;manual_seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'cuda'&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;#device = 'cpu'
&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;      &lt;span class=&quot;c1&quot;&gt;# Input and output dimension
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;         &lt;span class=&quot;c1&quot;&gt;# Rank of ground truth matrix
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;256&lt;/span&gt;       &lt;span class=&quot;c1&quot;&gt;# Batch size
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;L&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# Number of layers
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;yscale&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;4e-2&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Scale factor for ground truth
&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;truth&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;yscale&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;cnn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Sequential&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;L&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BatchNorm1d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ReLU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()]][:&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;sgd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SGD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cnn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;2.5e-8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;momentum&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;transformer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Sequential&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;L&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ReLU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;LayerNorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]][:&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;adam&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Adam&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;6e-4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;label&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cnn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sgd&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'CNN'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;adam&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'Transformer'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]:&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;log&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bi&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tqdm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;//&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;    &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'bi,bi-&amp;gt;b'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;truth&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pred&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'bi,bi-&amp;gt;b'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pred&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bi&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zero_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;label&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;label&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;legend&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;show&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;&lt;img src=&quot;/images/realistic_online_bias.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;h2 id=&quot;interesting-phenomena&quot;&gt;Interesting phenomena&lt;/h2&gt;
&lt;p&gt;While performing experiments like the ones above, I recognized numerous phenomena seen elsewhere in deep learning. This makes me hopeful that this simple setting is a good place to study and understand phenomena which are otherwise hard to specify and analyze.&lt;/p&gt;

&lt;h3 id=&quot;small-initialization--large-targets&quot;&gt;Small initialization / large targets&lt;/h3&gt;
&lt;p&gt;Implicit bias by small initialization seems crucial to the experiments. If the initialization is too large compared to the ground truth (so \(\gamma\) too small), all the networks seem unable to train quickly and generalize. Interestingly, while the simple settings require very large \(\gamma\), the more realistic settings (more layers, normalization layers) work best with smaller values of \(\gamma\). When doing classification instead of regression, the standard classification losses (say cross entropy with label smoothing) have a slightly larger target scale by default. That could mean we don’t need any explicit target scaling. We saw this effect &lt;a href=&quot;/2022/07/06/dln_classifier.html&quot;&gt;in a previous blog post&lt;/a&gt;.&lt;/p&gt;

&lt;h3 id=&quot;only-works-well-for-large-instances&quot;&gt;Only works well for large instances&lt;/h3&gt;
&lt;p&gt;Deep learning is very data hungry, with simpler methods beating it when data is scarce. The single epoch training regime might shed some light on why that is, as it also only works for large instances.&lt;/p&gt;

&lt;h3 id=&quot;learning-rate-warm-up&quot;&gt;Learning rate warm-up&lt;/h3&gt;
&lt;p&gt;In practical deep learning, “learning rate warm-up” is an important trick where we gradually increase the learning rate at the start of training. By using this trick, I was able to use larger learning rates later without diverging. For example, we can add &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;opt.param_groups[0]['lr'] = 1e-7*min(1,0.1+bi/(d**2/4))&lt;/code&gt; before &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;opt.step()&lt;/code&gt; in the CNN setting.&lt;/p&gt;

&lt;h3 id=&quot;normalization-layers&quot;&gt;Normalization layers&lt;/h3&gt;
&lt;p&gt;Normalization layers seem to help train deeper networks. However, it is not well understood why. Maybe analyzing the effect of normalization layers in this simple setting can give some insights.&lt;/p&gt;

&lt;h2 id=&quot;theory&quot;&gt;Theory&lt;/h2&gt;
&lt;p&gt;The theory of online optimization is well suited to describe single epoch SGD. As a brief introduction to online optimization, I will give a simple proof of how single epoch SGD optimizes convex functions. The proof is mostly taken from “Convex Optimization: Algorithms and Complexity” by Sébastien Bubeck.&lt;/p&gt;

&lt;p&gt;We want to optimize the differentiable (or we could work with subgradients), convex function \(\tilde f \colon \mathbb{R}^d \to \mathbb{R}\). To apply SGD, we sample from some family of differentiable functions \(f \colon \mathbb{R}^d \to \mathbb{R}\) satisfying \(\mathbb{E}\|\nabla f\|_2^2 \le B^2\) and \(\mathbb{E}\nabla f = \nabla \tilde f\). Next, we assume there exists some minimizer \(x^* \in \mathbb{R}\) of \(\tilde f\), and we have some initial guess \(x^1\). The SGD update rule is&lt;/p&gt;

\[x^{k+1} = x^k-\eta\nabla f^k(x^k),\]

&lt;p&gt;where \(f^k\) is randomly sampled.&lt;/p&gt;

&lt;p&gt;By convexity \(\tilde f(x^k) - \tilde f(x^*) \le \nabla \tilde f(x^k)^\top(x^k-x^*)\). So&lt;/p&gt;

\[\begin{align*}
\mathbb{E}\min_{k \in \{1,\dots,K\}} \tilde f(x^k) - \tilde f(x^*) 
\le \frac{1}{K} \mathbb{E}\sum_{k=1}^K \tilde f(x^k) - \tilde f(x^*)\\
\le \frac{1}{K} \mathbb{E}\sum_{k=1}^K \nabla \tilde f(x^k)^\top (x^k-x^*)
= \frac{1}{K} \mathbb{E}\sum_{k=1}^K \nabla f^k(x^k)^\top (x^k-x^*)
\end{align*}\]

&lt;p&gt;Using \(x^{k+1}-x^k = -\eta\nabla f^k(x^k)\) and \(2u^\top v = \|u\|_2^2+\|v\|_v^2-\|u-v\|_2^2\), we calculate&lt;/p&gt;

\[\begin{align*}
\nabla f^k(x^k)^\top (x^k-x^*)
&amp;amp;= \frac{1}{\eta}(x^k-x^{k+1})^\top (x^k-x^*)\\
&amp;amp;= \frac{1}{2\eta}(\|x^k-x^{k+1}\|_2^2+\|x^k-x^*\|_2^2-\|x^*-x^{k+1}\|_2^2)\\
&amp;amp;= \frac{1}{2\eta}(\|x^k-x^*\|_2^2-\|x^{k+1}-x^*\|_2^2)+\frac{\eta}{2}\|\nabla f(x^k)\|_2^2
\end{align*}\]

&lt;p&gt;Inserting into the previous expression, we get a telescoping sum. We pick \(\eta = \frac{\|x^1-x^*\|_2}{B\sqrt K}\) to minimize the final bound. Hence&lt;/p&gt;

\[\begin{align*}
&amp;amp;\frac{1}{K} \mathbb{E}\sum_{k=1}^K \nabla f^k(x^k)^\top (x^k-x^*)\\
= &amp;amp;\frac{1}{2K\eta} \mathbb{E}\left(\|x^1-x^*\|_2^2-\|x^{K+1}-x^*\|_2^2\right) + \frac{\eta}{2K}\mathbb{E}\sum_{k=1}^K \|\nabla f(x^k)\|_2^2\\
\le &amp;amp;\frac{1}{2K\eta} \|x^1-x^*\|_2^2 + \frac{\eta B^2}{2}
= \frac{\|x^1-x^*\|_2B}{\sqrt K}.
\end{align*}\]

&lt;p&gt;In summary, \(\mathbb{E}\min_{k \in \{1,\dots,K\}} \tilde f(x^k) - \tilde f(x^*) \le \frac{\|x^1-x^*\|_2B}{\sqrt K}\). This shows that in expectation, we successfully optimize the function with rate \(\frac{1}{\sqrt K}\).&lt;/p&gt;</content><author><name></name></author><summary type="html">Large transformer models are typically trained using only a single epoch, i.e the model only sees each data point once. Theoretically, this training regime circumvents problems related to overfitting and generalization by converting them to a question of how fast we can converge given noisy gradients.</summary></entry><entry><title type="html">Start here: Why I care about implicit biases</title><link href="https://johanwind.github.io/2022/08/29/intro_implicit_bias.html" rel="alternate" type="text/html" title="Start here: Why I care about implicit biases" /><published>2022-08-29T00:00:00+00:00</published><updated>2022-08-29T00:00:00+00:00</updated><id>https://johanwind.github.io/2022/08/29/intro_implicit_bias</id><content type="html" xml:base="https://johanwind.github.io/2022/08/29/intro_implicit_bias.html">&lt;p&gt;My research, and this blog, are centered around implicit biases in deep learning. But what even are “implicit biases”, and why do I care about them? In this post, I try to explain my motivations, why I think understanding implicit biases is the key to unlocking the potential of deep learning.&lt;/p&gt;

&lt;h1 id=&quot;why-i-care-about-deep-learning&quot;&gt;Why I care about deep learning&lt;/h1&gt;
&lt;p&gt;It works &lt;em&gt;really&lt;/em&gt; well for some problems. Just look at the following image generated using deep learning (by &lt;a href=&quot;https://www.midjourney.com/showcase&quot;&gt;Midjourney&lt;/a&gt;):&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://www.creativeshrimp.com/wp-content/uploads/2022/06/midjourney_aiart_gleb_alexandrov_06-1170x731.jpg&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Making a program to automatically draw a beautiful image from a text prompt would be practically impossible, before deep learning came along and did it. Other problems where deep learning is miles ahead of the competition include &lt;a href=&quot;https://paperswithcode.com/sota/image-classification-on-imagenet&quot;&gt;image classification&lt;/a&gt;, &lt;a href=&quot;https://www.deepmind.com/publications/playing-atari-with-deep-reinforcement-learning&quot;&gt;playing games from pixels&lt;/a&gt; and &lt;a href=&quot;http://www.mattmahoney.net/dc/text.html&quot;&gt;lossless text compression&lt;/a&gt;. I believe deep learning will continue to deliver breakthroughs in other areas, and I’m excited to see what those are.&lt;/p&gt;

&lt;h1 id=&quot;the-problem-with-deep-learning&quot;&gt;The problem with deep learning&lt;/h1&gt;
&lt;p&gt;Ok, so deep learning has amazing potential, what are the problems we need to overcome to achieve that potential? In my opinion, the main problem with modern deep learning is the huge amount of engineering effort it currently requires.&lt;/p&gt;

&lt;p&gt;To get the best practical performance from deep learning, you need to add a bunch of small (but hugely important) tricks. These tricks take time to learn, and take time to apply and adjust to new problems. An example of such a trick is using &lt;em&gt;data augmentation&lt;/em&gt; when training an image classifier. Concretely, a simple data augmentation would be adding horizontally flipped images to your dataset, effectively doubling the size of your dataset.&lt;/p&gt;

&lt;p&gt;Another problem is that modern deep learning requires enormous computations. More compute gives better results, so naturally you put in as much compute as you can afford. In practice, this means waiting hours or days for a model to train. This drastically increases the time required to test new code, and in general slows down development.&lt;/p&gt;

&lt;p&gt;Tuning the &lt;em&gt;hyperparameters&lt;/em&gt; is also a time-consuming part of modern deep learning. The performance of a deep learning model critically depends on numerous of tuning parameters, which need to be carefully chosen when applying the model to a new problem. Here is a list of some common hyperparameters:&lt;/p&gt;

&lt;table&gt;
  &lt;thead&gt;
    &lt;tr&gt;
      &lt;th&gt;Hyperparameter&lt;/th&gt;
      &lt;th style=&quot;text-align: right&quot;&gt;Typical value&lt;/th&gt;
      &lt;th style=&quot;text-align: right&quot;&gt;Affects model expressivity&lt;/th&gt;
    &lt;/tr&gt;
  &lt;/thead&gt;
  &lt;tbody&gt;
    &lt;tr&gt;
      &lt;td&gt;Learning rate&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;3e-4&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Momentum&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;0.9&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Learning rate schedule&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;Cosine&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Optimizer&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;Adam&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Batch size&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;32&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Number of training epochs&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;300&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Weight decay&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;0&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Activation function&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;ReLU&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;Yes&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Weight initialization&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;Xavier&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Feature dimension&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;512&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;Yes&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Label smoothing&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;0.1&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;Dropout probability&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;0.2&lt;/td&gt;
      &lt;td style=&quot;text-align: right&quot;&gt;No&lt;/td&gt;
    &lt;/tr&gt;
  &lt;/tbody&gt;
&lt;/table&gt;

&lt;p&gt;In practice, these hyperparameters are chosen using a combination of the engineer’s experience, and repeatedly testing the model with different hyperparameter configurations. Note that changing one parameter might change the effects of other parameters, making it exponentially harder. Recall that testing the model is computationally expensive and slow. The combination makes for a painful engineering experience. The “correct” solution is to run an automated search through many hyperparameter combinations, and pick the best. However, that is computationally expensive. So you would generally rather train a more computationally expensive model giving better results.&lt;/p&gt;

&lt;p&gt;I believe the applicability of deep learning is severely limited by the huge amount of engineering effort it requires. So what is the solution?&lt;/p&gt;

&lt;h1 id=&quot;the-need-for-theory&quot;&gt;The need for theory&lt;/h1&gt;
&lt;p&gt;Let’s compare with classical machine learning algorithms, things like linear regression (on possibly nonlinear, hand-engineered features). Applying those algorithms can often be reduced to minimizing some convex loss function \(L\) plus some convex regularizer \(R\) times some scalar weight \(\lambda\),&lt;/p&gt;

\[\min_\theta L(\theta)+\lambda R(\theta).\]

&lt;p&gt;We can then use some numerical optimization algorithm to find the unique minimum at parameter configuration \(\theta^*\), and use those parameters to produce predictions. There are several hyperparameters in the optimization algorithm, but those only affect the time to convergence, so they can be left to reasonable default values. Mathematicians proved correctness of the optimization algorithms, so the right \(\theta^*\) is found every time. This mature theory means practitioners can focus most of their effort on making good models, since then applying the models is relatively straightforward.&lt;/p&gt;

&lt;p&gt;But modern deep learning is also optimizing a loss function? We have optimizers which can be guaranteed to find (local) minima. What’s the problem? The problems start when we realize that modern deep learning methods can have way more parameters than the number of data points they are trying to fit. Even for small datasets like CIFAR with 50 000 training images, deep learning models use millions of parameters. The models are &lt;em&gt;overparameterized&lt;/em&gt;. &lt;a href=&quot;https://arxiv.org/abs/1611.03530&quot;&gt;Deep learning models can fit random training labels&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;Because of overparameterization, there are &lt;em&gt;many&lt;/em&gt; different parameter configurations giving 0 training loss (or arbitrarily small loss in the case of cross entropy classification loss). To make matters worse, deep learning models don’t seem to require explicit regularizers (specifically, weight decay is optional). As a result, in deep learning, our training algorithm and hyperparameters &lt;em&gt;do&lt;/em&gt; affect what model we end up with. I call it the &lt;em&gt;implicit bias&lt;/em&gt; which determines the final model, among the many models minimizing the loss function. The long list of hyperparameters in the table above can (and empirically does!) affect the implicit bias and performance of the final model.&lt;/p&gt;

&lt;p&gt;Classically, the main important aspect of a model is what kind of functions it can express, its &lt;em&gt;expressivity&lt;/em&gt;. However, in deep learning, most of the important choices don’t even affect the expressivity, they only affect the implicit bias (see the table above). Sadly, we have no better description of this implicit bias than “run exactly this algorithm with these hyperparameters, that should give you the implicit bias baked into your final model”.&lt;/p&gt;

&lt;p&gt;As a concrete example of why this is a problem, say you made a fantastic new second order optimization algorithm, it optimizes the loss function 10x faster than standard first order methods! The current deep learning models were developed and compared under the implicit bias given by current optimizers. Chances are that your second order optimizer significantly changes that implicit bias, giving worse performance, since the model was not built for the new implicit bias. The algorithm couldn’t be used.&lt;/p&gt;

&lt;p&gt;Let’s say we found a way to better characterize the implicit bias of modern deep learning models. Maybe we could improve it. Maybe we could train models faster, without losing out on performance. Maybe we could get rid of annoying hyperparameters. I want to find out.&lt;/p&gt;</content><author><name></name></author><summary type="html">My research, and this blog, are centered around implicit biases in deep learning. But what even are “implicit biases”, and why do I care about them? In this post, I try to explain my motivations, why I think understanding implicit biases is the key to unlocking the potential of deep learning.</summary></entry><entry><title type="html">Technical: Deep Linear Networks with label noise minimize the nuclear norm</title><link href="https://johanwind.github.io/2022/08/10/dln_reg.html" rel="alternate" type="text/html" title="Technical: Deep Linear Networks with label noise minimize the nuclear norm" /><published>2022-08-10T00:00:00+00:00</published><updated>2022-08-10T00:00:00+00:00</updated><id>https://johanwind.github.io/2022/08/10/dln_reg</id><content type="html" xml:base="https://johanwind.github.io/2022/08/10/dln_reg.html">&lt;p&gt;The implicit biases in (stochastic) gradient descent with large learning rates often reduce to regularizers preferring “flat minima”. See for example &lt;a href=&quot;/2022/07/22/noisy_reg.html&quot;&gt;my blog with label noise&lt;/a&gt;. The regularizers depend on the network architecture, so I find it interesting to characterize them for tractable cases such as Deep Linear Networks.&lt;/p&gt;

&lt;p&gt;This blog is more technical, mathy and less standalone than previous blogs. However, to my knowledge only the two layer case has been analyzed at the time of writing, so I thought I would put this more general result out there.&lt;/p&gt;

&lt;h2 id=&quot;deep-linear-networks-dlns&quot;&gt;Deep Linear Networks (DLNs)&lt;/h2&gt;
&lt;p&gt;We already encountered DLNs in a &lt;a href=&quot;/2022/07/06/dln_classifier.html&quot;&gt;previous blog post&lt;/a&gt;. However, we will introduce notation useful for this post. Let’s consider an \(L\)-layer DLN with dimensions \(d_0,d_1,\dots,d_L\). I.e we are parameterizing a \(d_0 \times d_L\) matrix \(\tilde W\) as&lt;/p&gt;

\[\tilde W = W_1 W_2 \cdots W_L\]

&lt;p&gt;where \(W_i\) is a \(d_{i-1}\times d_i\) matrix. We will assume \(L \ge 2\) and \(d_i \ge \min\{d_0,d_L\}\), so the model is overparameterized.&lt;/p&gt;

&lt;p&gt;The regularizer we will analyze is inspired by the following slight generalization of matrix completion. Given a \(d_0\)-dimensional vector \(x\) and a \(d_L\) dimensional vector \(y\) our model predicts \(P(x,y) = x^\top \tilde W y\). For a set of \(n\) data points \((x_i,y_i)\) with targets \(t_i\) we can then do regression with the loss \(\sum_{i=1}^n (P(x_i,y_i)-t_i)^2\). If \(x_i\) and \(y_i\) are standard unit vectors, we get a formulation of matrix completion.&lt;/p&gt;

&lt;p&gt;We can then use the same arguments as in &lt;a href=&quot;/2022/07/22/noisy_reg.html&quot;&gt;the first blog post on implicit regularization by large learning rate&lt;/a&gt; to find the regularizer introduced by label noise and large learning rate, namely&lt;/p&gt;

\[R(W_1,\dots,W_L) = \sum_{i=1}^n \sum_{l=1}^L \|\nabla_{W_l} P(x_i,y_i)\|_F^2\]

&lt;p&gt;where \(\|\cdot\|_F\) is the Frobenius norm. To summarize roughly: among all the solutions \(W_1,\dots,W_L\) minimizing the loss, gradient descent with label noise will pick the one minimizing \(R(W_1,\dots,W_L)\).&lt;/p&gt;

&lt;p&gt;The rest of the blog post will try to analyze the regularizer \(R\). We will eventually see that it is similar to the nuclear norm regularizer \(\|\tilde W\|_*\), which is well known in matrix completion and is known to encourage low rank.&lt;/p&gt;

&lt;p&gt;As a first step we differentiate \(P\) and split the resulting Frobenius norm to get&lt;/p&gt;

\[\sum_{i=1}^n \sum_{l=1}^L \|\nabla_{W_l} P(x_i,y_i)\|_F^2 = \sum_{i=1}^n \sum_{l=1}^L \|x_i^\top W_1\cdots W_{l-1}\|_2^2 \|W_{l+1}\cdots W_L y_i\|_2^2.\]

&lt;h3 id=&quot;assumption&quot;&gt;Assumption&lt;/h3&gt;
&lt;p&gt;To get a nice expression for the regularizer, we will make the following assumption: There exists matrices \(X\) and \(Y\) such that&lt;/p&gt;

\[\sum_{i=1}^n\ (x_i x_i^\top) \otimes (y_i y_i^\top) = (X^\top X) \otimes (Y Y^\top)\]

&lt;p&gt;where \(\otimes\) is the Kronecker product. We will furthermore assume that \(X\) and \(Y\) are square, invertible matrices. I expect the results in this blog to also hold for singular \(X\) and \(Y\), but that introduces some technical difficulties.&lt;/p&gt;

&lt;p&gt;In my intuition, \(X^\top X\) and \(Y Y^\top\) are basically the covariances of the vectors \(\{x_i\}_i\) and the vectors \(\{y_i\}_i\). The assumption is an independence assumption between \(\{x_i\}_i\) and \(\{y_i\}_i\), it is pretty much saying \(\mathbb{E}\big((x x^\top) \otimes (y y^\top)\big) = \mathbb{E}(x x^\top)\otimes \mathbb{E}(y y^\top)\).&lt;/p&gt;

&lt;p&gt;While this assumption is often not exactly satisfied, I hope it is approximately satisfied when \(x_i\) and \(y_i\) are sampled independently, such as usually done in matrix completion. My attempts at \(L &amp;gt; 2\) without the assumption ended in ugly tensor norms which were hard to work with and interpret.&lt;/p&gt;

&lt;p&gt;Using the assumption, we may simplify&lt;/p&gt;

\[\sum_{i=1}^n \sum_{l=1}^L \|x_i^\top W_1\cdots W_{l-1}\|_2^2 \|W_{l+1}\cdots W_L y_i\|_2^2 = \sum_{l=1}^L \|X\ W_1\cdots W_{l-1}\|_F^2 \|W_{l+1}\cdots W_L Y\|_F^2.\]

&lt;h2 id=&quot;lower-bound&quot;&gt;Lower bound&lt;/h2&gt;
&lt;p&gt;In this section we will use some inequalities to find a tight lower bound on \(R\). Using the AM–GM inequality we have&lt;/p&gt;

\[\sum_{l=1}^L \|X\ W_1\cdots W_{l-1}\|_F^2 \|W_{l+1}\cdots W_L Y\|_F^2 \ge L \left(\prod_{l=1}^L \|X\ W_1\cdots W_{l-1}\|_F^2 \|W_{l+1}\cdots W_L Y\|_F^2\right)^{\frac{1}{L}}.\]

&lt;p&gt;Rearrange the product&lt;/p&gt;

\[\begin{align}
&amp;amp;L \left(\prod_{l=1}^L \|X\ W_1\cdots W_{l-1}\|_F^2 \|W_{l+1}\cdots W_L Y\|_F^2\right)^{\frac{1}{L}}\\
=\ &amp;amp;L \left(\|X\|_F\|Y\|_F\prod_{l=1}^{L-1} \|X\ W_1\cdots W_l\|_F \|W_{l+1}\cdots W_L Y\|_F\right)^{\frac{2}{L}}.
\end{align}\]

&lt;p&gt;Use \(\|A\|_F\|B\|_F \ge \|AB\|_*\) where \(\|\cdot\|_*\) is the nuclear norm&lt;/p&gt;

\[\begin{align}
&amp;amp;L \left(\|X\|_F\|Y\|_F\prod_{l=1}^{L-1} \|X\ W_1\cdots W_l\|_F \|W_{l+1}\cdots W_L Y\|_F\right)^{\frac{2}{L}}\\
\ge\ &amp;amp;L \left(\|X\|_F\|Y\|_F\prod_{l=1}^{L-1} \|X\ W_1\cdots W_L Y\|_*\right)^{\frac{2}{L}}\\
=\ &amp;amp;L\left(\|X\|_F\|Y\|_F\right)^{\frac{2}{L}}\|X\tilde W Y\|_*^{2-\frac{2}{L}}.
\end{align}\]

&lt;h2 id=&quot;construction-achieving-the-bound&quot;&gt;Construction achieving the bound&lt;/h2&gt;
&lt;p&gt;In this section we will give a construction which achieves the lower bound, showing that it is tight. Let’s say we are given some \(\tilde W\) and need to find \(W_1 \cdots W_L = \tilde W\) such that \(R(W_1,\dots,W_L)\) is minimized.&lt;/p&gt;

&lt;p&gt;Let’s take the (compact) singular value decomposition of \(X \tilde W Y = U \Sigma V^\top\). Here \(U\) and \(V\) are semi-orthogonal matrices with dimensions \(d_0 \times r\) and \(d_L \times r\) and \(\Sigma\) is a diagonal \(r \times r\) matrix with positive diagonal, where \(r\) is the rank of \(\tilde W\) (which is also the rank of \(X \tilde W Y\) since \(X\) and \(Y\) are invertible). We can now construct&lt;/p&gt;

\[\begin{align}
W_1 &amp;amp;= \alpha \begin{pmatrix}X^{-1}U\Sigma^{\frac{1}{2}} &amp;amp; \bf 0\end{pmatrix}\\
W_i &amp;amp;= \beta \begin{pmatrix}I_r &amp;amp; \bf 0 \\ \bf 0 &amp;amp; \bf 0\end{pmatrix} \hspace{2.5cm} i = 2,\dots,L-1\\
W_L &amp;amp;= \gamma \begin{pmatrix}\Sigma^{\frac{1}{2}}V^\top Y^{-1} \\ \bf 0\end{pmatrix}
\end{align}\]

&lt;p&gt;where \(\beta = \left(\frac{\|X\tilde W Y\|_*}{\|X\|_F\|Y\|_F}\right)^{\frac{1}{L}}\), \(\alpha = \sqrt{\frac{\|X\|_F}{\|Y\|_F}}\beta^{1-\frac{L}{2}}\) and \(\gamma = \sqrt{\frac{\|Y\|_F}{\|X\|_F}}\beta^{1-\frac{L}{2}}\). \(\bf 0\) are zero matrices of appropriate dimensions to match the dimensions of the left-hand side. \(I_r\) is the \(r\times r\) identity matrix.&lt;/p&gt;

&lt;p&gt;It can be verified that \(W_1 \cdots W_L = \tilde W\) and&lt;/p&gt;

\[\sum_{l=1}^L \|X\ W_1\cdots W_{l-1}\|_F^2 \|W_{l+1}\cdots W_L Y\|_F^2 = L\left(\|X\|_F\|Y\|_F\right)^{\frac{2}{L}}\|X\tilde W Y\|_*^{2-\frac{2}{L}}.\]

&lt;h2 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;We showed that the minimizing assignment of \(W_1,\dots,W_L\) give the regularizer&lt;/p&gt;

\[R(W_1,\dots,W_L) = L\left(\|X\|_F\|Y\|_F\right)^{\frac{2}{L}}\|X\tilde W Y\|_*^{2-\frac{2}{L}}.\]

&lt;p&gt;Since multiplication by constants and positive powers don’t change the minimum, we are basically minimizing \(\|X\tilde W Y\|_*\). For isotropic data we have \(X \propto I\) and \(Y \propto I\), so we get the usual nuclear norm regularizer \(\|\tilde W\|_*\). The regularizer \(\|X\tilde W Y\|_*\) can be interpreted as first normalizing the distributions of \(x_i\) and \(y_i\) to be isotropic, and then using the usual nuclear norm regularizer on the processed data. Specifically, let \(\tilde W' = X\tilde W Y\) and consider a loss function \(f\) depending only on the predictions \(P(x_i,y_i) = x_i^\top \tilde W y_i\). Then the change of variables to the loss function and regularizer yields&lt;/p&gt;

\[f(\{x_i^\top\tilde W y_i\}_i) + \lambda \|X\tilde W Y\|_* = f(\{(x_i^\top X^{-1})\tilde W' (Y^{-1}y_i)\}_i) + \lambda \|\tilde W'\|_*.\]</content><author><name></name></author><summary type="html">The implicit biases in (stochastic) gradient descent with large learning rates often reduce to regularizers preferring “flat minima”. See for example my blog with label noise. The regularizers depend on the network architecture, so I find it interesting to characterize them for tractable cases such as Deep Linear Networks.</summary></entry><entry><title type="html">Implicit bias by large learning rate: Noise can be helpful for gradient descent</title><link href="https://johanwind.github.io/2022/07/22/noisy_reg.html" rel="alternate" type="text/html" title="Implicit bias by large learning rate: Noise can be helpful for gradient descent" /><published>2022-07-22T00:00:00+00:00</published><updated>2022-07-22T00:00:00+00:00</updated><id>https://johanwind.github.io/2022/07/22/noisy_reg</id><content type="html" xml:base="https://johanwind.github.io/2022/07/22/noisy_reg.html">&lt;p&gt;I will demonstrate how large learning rates can lead to implicit biases in a simple regression task. Code to reproduce the results can be found in spoilers. We will use pytorch.&lt;/p&gt;

&lt;h2 id=&quot;the-regression-task&quot;&gt;The regression task&lt;/h2&gt;
&lt;p&gt;We generate normally distributed data points in \(d = 30\) dimensions. The regression target is the square of the first feature dimension. The other 29 dimensions are only there to make the task harder. We generate \(n = 200\) data points for training and another \(200\) data points for testing.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to generate data&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;manual_seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;30&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Input dimension
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;200&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Training examples / testing points
&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Generate random points
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# Ground truth
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,:],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Train/test split
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:,:],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;&lt;img src=&quot;/images/noisy_reg_data.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;h2 id=&quot;the-neural-network&quot;&gt;The neural network&lt;/h2&gt;
&lt;p&gt;We will use a single hidden layer neural network with a quadratic activation function. The hidden layer will have \(m = 100\) nodes. The parameters of this network are a \(d \times m\) matrix \(A\) and a \(m\)-dimensional vector \(b\). For a \(d\)-dimensional data point \(X\), the model predicts&lt;/p&gt;

\[\text{predict}(X) = \sum_{i=1}^m b_i (A_i \cdot X)^2.\]

&lt;p&gt;This neural network should be well suited to solve the regression task since it is straightforward to find values for the parameters that solve the task exactly, for example we may pick \(A_{11} = b_1 = 1\) and set everything else to zero.&lt;/p&gt;

&lt;p&gt;We initialize \(A\) by Xavier/Glorot normal initialization and \(b\) to zero, i.e \(A_{ij} \sim \mathcal{N}(0,\frac{2}{d+m})\) and \(b_i = 0\).&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to make the neural network&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;makeModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seed&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;m&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Number of hidden nodes
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;m&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;m&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;manual_seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;To fit the model we run gradient descent on the Mean Square Error (MSE) over the training data. We train until the MSE on the training data is below \(10^{-4}\).&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to train the neural network&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e100&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;while&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;no_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Gradient descent
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Return test MSE
&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'MSE = %.2f'&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;makeModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.01&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;We choose the learning rate 0.01 and train the neural network. It scores MSE = 0.40 on the test set. This is quite bad, as we clearly did not solve the task. As a reference for the scale of MSE, the model always predicting 1 gets MSE = 1.97 .&lt;/p&gt;

&lt;h2 id=&quot;increasing-learning-rate&quot;&gt;Increasing learning rate&lt;/h2&gt;
&lt;p&gt;In deep learning, performance is often very dependent on hyperparameters. Let’s look at the effect of the learning rate.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to compare different learning rates&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;lr_list&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.001&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.15&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.005&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;mse_list&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr_list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;mse&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;makeModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;MSE = %.2f&quot;&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mse&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;&lt;img src=&quot;/images/noisy_reg_lr.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;For learning rates above 0.096 the optimization diverges, so we can’t go higher. We see that all learning rates up to 0.03 give the same bad MSE, but after that larger learning rates improve performance. Interestingly, we can actually solve the task, but only if we choose learning rates right at the edge of diverging.&lt;/p&gt;

&lt;p&gt;It should be noted that while the above plot paints a deceptively simple picture, it is not true in general that higher learning rate is better. The experiment seems robust against random seeds for initialization and data generation, but is quite fragile against changes in other hyperparameters, such as the number of hidden neurons and scale of initialization.&lt;/p&gt;

&lt;h2 id=&quot;label-noise&quot;&gt;Label noise&lt;/h2&gt;
&lt;p&gt;We can achieve a similar effect without huge learning rates by adding “label noise”. Label noise is a form a &lt;em&gt;data augmentation&lt;/em&gt;, where we generate more training data by modifying the original data. In each iteration of gradient descent, we will replace the training targets by the original targets plus some noise. We choose standard normal noise.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to train with label noise&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;trainModelWithLabelNoise&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;no_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Gradient descent
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Return test MSE
&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'MSE = %.2f'&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trainModelWithLabelNoise&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;makeModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.03&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100000&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;Training with learning rate 0.03 (which previously gave MSE = 0.40) now gives MSE = 0.01! However, notice it required 100000 gradient descent steps to achieve this. Without label noise we only needed 1000 steps to converge (then the gradient is practically zero, meaning that parameters stop changing and running longer doesn’t change anything). Lower learning rates require even more steps.&lt;/p&gt;

&lt;h2 id=&quot;the-explanation&quot;&gt;The explanation&lt;/h2&gt;
&lt;p&gt;We will derive an explicit regularizer which approximates the implicit regularization by large step gradient descent.&lt;/p&gt;

&lt;p&gt;Let’s think of gradient descent as an approximation to &lt;em&gt;gradient flow&lt;/em&gt; (think of gradient descent with infinitesimal step length). We have some loss function \(L\) which we optimize over some parameters \(\theta\). In our case \(L = \frac{1}{n}\sum_{i=1}^n(P(X_i)-y_i)^2\), where \(P(X)\) is shorthand for \(\text{predict}(X)\), and we can stack the parameters into a vector \(\theta = (A,b)\).&lt;/p&gt;

&lt;p&gt;Consider a single step of gradient descent. We start the step with parameters \(\theta(0)\) and go to \(\theta(0)-\eta\nabla L\), where \(\eta\) is the learning rate. The gradient flow solution satisfies \(\dot{\theta} = -\nabla L\). Differentiating again we get \(\ddot{\theta} = -\nabla^2 L\dot{\theta} = \nabla^2 L \nabla L\). We can therefore Taylor expand&lt;/p&gt;

\[\theta(\eta) = \theta(0) + \eta\dot{\theta} + \frac{\eta^2}{2}\ddot{\theta} + O(\eta^3) = \theta(0) - \eta\nabla L + \frac{\eta^2}{2}\nabla^2 L\nabla L + O(\eta^3).\]

&lt;p&gt;We see that the gradient descent step only captures the first two terms, giving a truncation error of \(O(\eta^2)\).&lt;/p&gt;

&lt;p&gt;Let’s change the loss function to&lt;/p&gt;

\[\tilde{L} = L + \frac{\eta}{4}\|\nabla L\|_2^2\]

&lt;p&gt;and look at the new gradient flow \(\dot{\theta} = -\nabla \tilde{L} = -\nabla L - \frac{\eta}{4}\nabla \|\nabla L\|_2^2 = - \nabla L - \frac{\eta}{2}\nabla^2 L \nabla L\). Differentiating again we have \(\ddot{\theta} = \nabla^2 L \nabla L + O(\eta)\). The Taylor expansion is now&lt;/p&gt;

\[\theta(\eta) = \theta(0) + \eta(- \nabla L - \frac{\eta}{2}\nabla^2 L \nabla L) + \frac{\eta^2}{2}\nabla^2 L\nabla L + O(\eta^3) = \theta(0) - \eta\nabla L + O(\eta^3).\]

&lt;p&gt;The gradient descent step &lt;em&gt;with respect to the original loss&lt;/em&gt; approximates the gradient flow with respect to the new loss with truncation error only \(O(\eta^3)\)!&lt;/p&gt;

&lt;p&gt;In this view, it is hence more accurate to say that gradient descent is optimizing \(\tilde{L} = L + \frac{\eta}{4}\|\nabla L\|_2^2\) than \(L\). So we implicitly added the regularizer \(\frac{\eta}{4}\|\nabla L\|_2^2\). When the learning rate \(\eta\) is large, this regularizer is not negligible.&lt;/p&gt;

&lt;p&gt;I found the neat derivation presented above in &lt;a href=&quot;https://arxiv.org/abs/2009.11162&quot;&gt;Implicit Gradient Regularization&lt;/a&gt;, there you can find more details.&lt;/p&gt;

&lt;h1 id=&quot;label-noise-1&quot;&gt;Label noise&lt;/h1&gt;
&lt;p&gt;But wait, optimizing \(L\) is the same as optimizing \(\tilde{L}\), right? A minimizer of \(L\) has gradient zero, so it is also a minimizer of \(\tilde{L}\). Well, we are in the overparameterized setting, so the optimization path might change &lt;em&gt;which&lt;/em&gt; minimizer we end up at. When we add label noise it becomes clearer.&lt;/p&gt;

&lt;p&gt;With label noise we have \(L = E_{z_i \sim \mathcal{N}(0,1)}\frac{1}{n}\sum_{i=1}^n(P(X_i)-y_i-z_i)^2\) and \(\tilde{L} = L + \frac{\eta}{4} E_{z_i \sim \mathcal{N}(0,1)}\|\nabla \frac{1}{n}\sum_{i=1}^n(P(X_i)-y_i-z_i)^2\|_2^2\). After a bit of calculation, we may simplify this to&lt;/p&gt;

\[\tilde{L} = \frac{1}{n}\sum_{i=1}^n(P(X_i)-y_i)^2 + \eta\frac{1}{n^2}\sum_{i=1}^n \|\nabla P(X_i)\|_2^2 + \eta\Big[\frac{1}{n}\sum_{i=1}^n (P(X_i)-y_i)\cdot \nabla P(X_i)\Big]^2 + 1.\]

&lt;p&gt;The \(+1\) doesn’t matter for optimization, so we remove it. At convergence \(P(X_i) \approx y_i\), so we have&lt;/p&gt;

\[\tilde{L} \approx \eta\frac{1}{n^2}\sum_{i=1}^n \|\nabla P(X_i)\|_2^2\]

&lt;p&gt;One way to interpret this is that among the parameter configurations fitting the training data, we choose the one minimizing the gradient of the neural network output with respect to the parameters. We may further write out and analyze \(\|\nabla P(X_i)\|_2^2\) for the case of our neural network, and that will likely show why this regularizer is useful for solving our regression task. I will not do that here.&lt;/p&gt;

&lt;p&gt;One of the main insights of this implicit regularizer is the multiplicative factor \(\eta\). It implies the strength of the regularizer is proportional to the learning rate. Consequently, while we can often fit the training data roughly \(\frac{1}{\eta}\) iterations of gradient descent, it might take \(\frac{1}{\eta^2}\) iterations to optimize the regularizer.&lt;/p&gt;

&lt;h1 id=&quot;comparing-with-small-initialization&quot;&gt;Comparing with small initialization&lt;/h1&gt;
&lt;p&gt;If you read the &lt;a href=&quot;/2022/07/06/dln_classifier.html&quot;&gt;previous blog post&lt;/a&gt;, you might wonder if we can apply the implicit bias given by small initialization to solve the regression task given here. Indeed you can! Simply scaling down the outputs of the neural network by a factor 100, we get MSE = 0.00 with learning rate 0.01.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to test small initialization&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;makeModelSmallInit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seed&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;m&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Number of hidden nodes
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;m&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;m&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;manual_seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'MSE = %.2f'&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;makeModelSmallInit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.01&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;When writing the previous blog entry I tried the reverse: applying regularization by large step sizes and label noise to the classification task. However, this implicit bias was not useful for solving that task.&lt;/p&gt;</content><author><name></name></author><summary type="html">I will demonstrate how large learning rates can lead to implicit biases in a simple regression task. Code to reproduce the results can be found in spoilers. We will use pytorch.</summary></entry><entry><title type="html">Implicit bias by small initialization: A simple classification task where deeper is better</title><link href="https://johanwind.github.io/2022/07/06/dln_classifier.html" rel="alternate" type="text/html" title="Implicit bias by small initialization: A simple classification task where deeper is better" /><published>2022-07-06T00:00:00+00:00</published><updated>2022-07-06T00:00:00+00:00</updated><id>https://johanwind.github.io/2022/07/06/dln_classifier</id><content type="html" xml:base="https://johanwind.github.io/2022/07/06/dln_classifier.html">&lt;p&gt;In this blog post I will show an example of implicit bias on a synthetic classification task. I will put code to reproduce the results in spoilers. We will implement everything using pytorch.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to import packages&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;manual_seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;h2 id=&quot;the-classification-task&quot;&gt;The classification task&lt;/h2&gt;
&lt;p&gt;Let’s classify \(d = 10\) dimensional input features into \(k = 10\) classes using \(n = 100\) training examples, and generate \(n = 100\) data points for testing.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to set constants&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Input dimension
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Classes
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Training examples / testing points&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;We generate independently normally distributed data points in d = 10 dimensions. The classes are determined by the direction of the first 2 dimensions. Note that the other 8 dimensions are only there to confuse the classifier. In the plots below: on the left we see that the classes are separated into 10 pizza slices by angle of the first 2 feature dimensions, on the right we see other dimensions which are useless for classification.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/images/dln_classifier_data.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to generate data&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Generate random points
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;atan2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pi&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;long&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Classify points
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,:],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Train/test split
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:,:],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;h2 id=&quot;the-models&quot;&gt;The models&lt;/h2&gt;
&lt;p&gt;First, let us consider the linear classifier given by a \(d \times k\) matrix \(W\) where we predict the class of the data point \(X\) as the class whose column in \(W\) has the maximum inner product with \(X\), i.e&lt;/p&gt;

\[\text{predict}(X) = \text{argmax}_{i \in \{1,\dots,k\}} X \cdot W_i.\]

&lt;p&gt;We optimize the model by gradient descent until we have mean cross entropy loss at most 0.01 on the training data. Note that 0.01 cross entropy is quite low, meaning we perfectly fit the training data (100% accuracy), which is only possible since our models are in the overparameterized regime.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to train model&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e100&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;while&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.01&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Optimize until mean cross entropy loss is &amp;lt;= 0.01
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cross_entropy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;no_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Gradient descent
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
  &lt;span class=&quot;c1&quot;&gt;# Return test accuracy
&lt;/span&gt;  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;argmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;We can now train the simple linear model.&lt;/p&gt;
&lt;details&gt; &lt;summary&gt;Code to test simple linear model&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Test accurracy:&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'%'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;We get an accuracy of 53%. The hyperparameters for learning rate (lr) and stopping threshold (0.01) don’t seem to matter much if they are small enough (but making them smaller takes more time to optimize).&lt;/p&gt;

&lt;p&gt;Let us now apply the rule of thumb from deep learning that “deeper is better” and add another layer.  We parameterize \(W = AB\) for a \(d \times k\) matrix \(A\) and \(k\times k\) matrix \(B\).&lt;/p&gt;

\[\text{predict}(X) = \text{argmax}_{i \in \{1,\dots,k\}} X \cdot [AB]_i.\]

&lt;p&gt;For the previous parameterization, the optimization problem was convex, but this time it is not. This means initialization matters (in particular, if we initialize \(A = B = 0\) we will have gradient 0 and not get anywhere), let’s therefore initialize by a standard initialization from deep learning called Xavier/Glorot normal initialization. Let us also run 10 times to average out the random initialization.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to test 2 layer model&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;B&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Equivalent to A = th.randn(d,k)*(2/(d+k))**.5
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Equivalent to B = th.randn(k,k)*(2/(k+k))**.5
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Test accurracy: %.1f ± %.1f %%&quot;&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;std&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;We get an accuracy of 60.3% with standard deviation 0.6%. The accuracy improved! How can this happen? Classical intuition tells us that parameterizing a matrix as a product of matrices is useless, as we can only represent exactly the same set of classifiers. The key is that we are in the overparameterized regime, where gradient descent has many solutions to choose from. The parameterization changes the &lt;em&gt;implicit bias&lt;/em&gt; of the model when it is trained by gradient descent, causing it to choose a different solution.&lt;/p&gt;

&lt;p&gt;Let’s go deeper!&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/images/dln_classifier_depth.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Clearly more layers give better accuracy.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to test models of various depths&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;L&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;L&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;==&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;product&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layer&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;product&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;layer&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;product&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;L&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;elif&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;L&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;3e-2&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Test accurracy: %.1f ± %.1f %%&quot;&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;std&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;h2 id=&quot;the-explanation&quot;&gt;The explanation&lt;/h2&gt;
&lt;p&gt;To understand why depth improves generalization accuracy, we need to note that our classification problem has “low rank” in a certain sense. Consider the optimal classifier matrix \(W\) with respect to&lt;/p&gt;

\[\text{predict}(X) = \text{argmax}_{i \in \{1,\dots,k\}} X \cdot W_i.\]

&lt;p&gt;The i’th column (so the direction of class i-1 in the 0-indexed code) of this matrix is given by \(W_i = \left(\sin\left(\frac{\pi(2i-1-k)}{k}\right), \cos\left(\frac{\pi(2i-1-k)}{k}\right), 0, \dots, 0\right)^\top\). Note that since only the first two rows of \(W\) are non-zero, we have \(\text{rank}(W) = 2\). Intuitively, there are much fewer matrices with rank 2 than general matrices, so if we can somehow implicitly bias our model towards low rank matrices (or preferably rank 2 matrices), we will likely get better classification accuracy.&lt;/p&gt;

&lt;p&gt;To demonstrate, let’s explicitly force \(W\) to be rank 2 by factoring it into a \(d \times 2\) matrix \(A\) and a \(2 \times k\) matrix \(B\) so \(W = AB\), and then repeat our experiment.&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code to test rank 2 model&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;B&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Test accurracy: %.1f ± %.1f %%&quot;&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;std&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;85% accuracy! That’s the best we’ve seen. Ok, but our other deep factorizations didn’t force the matrix to be low rank, so how did they exploit the low rank property?&lt;/p&gt;

&lt;p&gt;That is most easily seen if we scale down the outputs of a three layer factorization \(W = \frac{1}{10^4}ABC\).&lt;/p&gt;

&lt;details&gt; &lt;summary&gt;Code for near-zero initialized three layer factorization&lt;/summary&gt; 
&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;B&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;C&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;th&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xavier_normal_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e4&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trainModel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Test accurracy: %.1f ± %.1f %%&quot;&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;std&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
 &lt;/details&gt;

&lt;p&gt;We get an accuracy of ca 84%. Plotting the singular values of the product matrix \(\frac{1}{10^4}ABC\) during training, we see what is going on.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/images/dln_classifier_singular_small.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;We see that the singular values show up one by one. When only the first singular value is non-negligible, we are effectively searching through rank 1 matrices. After a while it introduces the next singular value to search rank 2 matrices. Then the optimization terminates because it fits the training data. The remaining singular values are left around the initialization magnitude \(\frac{1}{10^4}\).&lt;/p&gt;

&lt;p&gt;In contrast to scaling down the outputs, if we switch \(\frac{1}{10^4}\) to \(10^3\) so \(W = 10^3 ABC\) (and adapt the learning rate to lr = 0.001), we get accuracy 57.5% with standard deviation 1%. So we are back down to accuracies around the one layer case, since we removed most of the implicit bias.&lt;/p&gt;

&lt;p&gt;Now you might wonder how we got any benefit from depth in our original setup, where we seemingly didn’t have small initialization. The clue is that it is really the relative size of initialization to final size which determines whether we are in the “small initialization regime”. Since our overfitting of the cross entropy loss results in the final matrix \(W = ABC\) having singular values around size 100, the initialization of size around 1 becomes small in comparison. You can see this in the following plot of the evolution of singular values for the original 3 layer model.
&lt;img src=&quot;/images/dln_classifier_singular_normal.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The implicit bias demonstrated in this blog has many names given by different researchers, such as near-zero / small initialization regime, following “saddle-to-saddle” dynamics, anti-NTK regime and rich regime / rich limit. The experiment itself was motivated by Arora’s paper on &lt;a href=&quot;https://arxiv.org/abs/1905.13655&quot;&gt;Implicit Regularization in Deep Matrix Factorization&lt;/a&gt;. In that paper, they describe the differential equations for the singular values and what causes them to show up one by one to give the low rank implicit bias. Simply put, increasing the depth causes the effect to strengthen. The experiment in this paper is an adaptation of their experiment from matrix completion to classification.&lt;/p&gt;</content><author><name></name></author><summary type="html">In this blog post I will show an example of implicit bias on a synthetic classification task. I will put code to reproduce the results in spoilers. We will implement everything using pytorch.</summary></entry><entry><title type="html">Kernel methods are basically overparameterized linear regression</title><link href="https://johanwind.github.io/2022/06/29/kernel.html" rel="alternate" type="text/html" title="Kernel methods are basically overparameterized linear regression" /><published>2022-06-29T00:00:00+00:00</published><updated>2022-06-29T00:00:00+00:00</updated><id>https://johanwind.github.io/2022/06/29/kernel</id><content type="html" xml:base="https://johanwind.github.io/2022/06/29/kernel.html">&lt;p&gt;In this blog post I explain kernel methods and the intuition that they’re “basically just overparameterized linear regression”. First, I will explain what I mean by kernel methods and overparameterized linear regression. I will view them as two methods for solving the following problem:&lt;/p&gt;

&lt;p&gt;We are given \(n\) data points \(x_i\) where each data point is a d-dimensional vector. Each data point has a scalar target value \(y_i\). For example, the data points could describe houses by size, location and year of construction and the target value could be the price of the house. Next, we are given a new data point \(z\) (size, location and year of construction of a new house) and want to predict its target value (price of the new house).&lt;/p&gt;

&lt;h3 id=&quot;informal-description-of-kernel-methods-for-regression&quot;&gt;Informal description of kernel methods for regression&lt;/h3&gt;
&lt;p&gt;We may apply a kernel method to this problem as follows:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;
    &lt;p&gt;Pick your favorite &lt;em&gt;kernel function&lt;/em&gt; k(x,y). Intuitively, this is a function measuring the similarity between x and y, giving higher values to more similar inputs. A popular one is&lt;/p&gt;

\[k(x,y) = e^{-\|x-y\|_2^2}.\]
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Build the \(n \times n\) “kernel matrix” by&lt;/p&gt;

\[K_{ij} = k(x_i,x_j).\]
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Solve for the coefficient vector \(\alpha\)&lt;/p&gt;

\[K \alpha = y.\]
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Calculate similarities to the new data point \(z\) and use this to produce a prediction&lt;/p&gt;

\[prediction = \sum_{i=1}^d \alpha_i k(x_i, z).\]
  &lt;/li&gt;
&lt;/ol&gt;

&lt;h3 id=&quot;overparameterized-linear-regression&quot;&gt;Overparameterized linear regression&lt;/h3&gt;
&lt;p&gt;Let us now instead apply linear regression to the problem. Since the features \(x_i\) might be related to the target values \(y_i\) in a nonlinear way, it is often useful to preprocess the features before fitting a linear model. For example, if \(x_{i1}\) is the size of house \(i\) and \(x_{i2}\) its year of construction, we might want to add the feature \(x_{i1}\cdot x_{i2}\) to account for nonlinear interactions between \(x_{i1}\) and \(x_{i2}\). Maybe we also want to add something like \(\sin(x_{i2})\) or \(\exp(x_{i1}+x_{i2})\) or the constant value \(1\) (often called intercept). Or we just add all monomials (\(x_{i1}^{p_1}x_{i2}^{p_2}\dots\)) of the input features up to degree \(10\). The possibilities are endless. Let’s assume we processed our old d-dimensional features \(x_i\) into new m-dimensional features \(\tilde{x}_i = \varphi(x_i)\), where \(\varphi\) is our feature transform. Let’s also stack all these new features into a new \(n \times m\) matrix \(\tilde X\). Construct \(\tilde{z} = \varphi(z)\) in the same way.&lt;/p&gt;

&lt;p&gt;What if we end up with more features \(m\) than data points \(n\)? This is what we call an overparameterized model. Usually this means we can perfectly fit all the targets \(y_i\). And moreover, we can perfectly fit the targets in many different ways. It makes sense to pick some linear coefficients \(\beta\) such that \(\tilde{X}\beta = y\), i.e we fit the data. And let’s pick the one among those with minimal 2-norm \(\|\beta\|_2\). It turns out that we can calculate this by&lt;/p&gt;

\[\beta = \tilde{X}^\top(\tilde{X}\tilde{X}^\top)^{-1}y.\]

&lt;p&gt;Then we can make the prediction&lt;/p&gt;

\[prediction = \tilde{z} \beta.\]

&lt;h3 id=&quot;numerical-example&quot;&gt;Numerical example&lt;/h3&gt;
&lt;p&gt;To illustrate the two methods described above, I used them to interpolate the function \(\sin(2\pi x)\) in the interval \(0 \le x \le 1\) from 5 evenly spaced points. In this case the features \(x_i\) are simply the positions of the data points. For the overparamterized linear regression I used the 10 polynomial features \(\left(\frac{x_i}{10}\right)^p\) for \(p = 0,\dots,9\).&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/images/kernel_olr.png&quot; style=&quot;width: 100%; display: block; margin: 0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Here’s the python code:&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linspace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;   &lt;span class=&quot;c1&quot;&gt;# Training data
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pi&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# Training targets
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linspace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Data points to predict
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Kernel method
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;alpha&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linalg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;solve&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;kernel_prediction&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;alpha&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Overparameterized linear regression
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;features&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;X_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;features&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linalg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;solve&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;linear_prediction&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;features&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'o'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;label&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'Data points'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_prediction&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'--'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;label&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'Kernel method'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;linear_prediction&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'-.'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;label&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'Linear regression'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pi&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;label&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'Original function'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;legend&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;show&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Now that I’ve described the two methods, we can get to the point.&lt;/p&gt;
&lt;h3 id=&quot;kernel-methods-are-basically-overparameterized-linear-regression&quot;&gt;Kernel methods are basically overparameterized linear regression&lt;/h3&gt;
&lt;p&gt;Recall \(\varphi\) as the feature transform taking the old features to the new features in overparameterized linear regression. Now define the kernel function \(k(x,y) = \varphi(x)^\top \varphi(y)\). Then the kernel matrix becomes \(K = \tilde{X}\tilde{X}^\top\) and \(\alpha = K^{-1}y = (\tilde{X}\tilde{X}^\top)^{-1}y\). The final prediction becomes&lt;/p&gt;

\[prediction = \tilde{z}\tilde{X}^\top(\tilde{X}\tilde{X}^\top)^{-1}y.\]

&lt;p&gt;Does this look familiar? That’s because it is exactly the same prediction that we would get from doing overparameterized linear regression! So if we have a feature transform \(\varphi\), we can construct a corresponding kernel function \(k(x,y) = \varphi(x)^\top \varphi(y)\). What about the converse?&lt;/p&gt;

&lt;p&gt;It turns out that under some technical assumptions (look up Mercer’s theorem if you’re interested) on the kernel function \(k(x,y)\), we can find feature transforms \(\varphi\) such that \(k(x,y) \approx \varphi(x)^\top \varphi(y)\) to arbitrary accuracy. And importantly \(\varphi\) can be chosen without knowing the data points \(x_i\) or targets \(y_i\). So the trade-off here is that to get an accurate approximation of \(k(x,y)\) the feature transform \(\varphi\) might output a lot of features.&lt;/p&gt;

&lt;p&gt;We may pick some high accuracy \(10^{-100}\), much higher than what we usually compute with, and find some \(\varphi\) such that \(\lvert k(x,y)-\varphi(x)^\top \varphi(y)\rvert &amp;lt; 10^{-100}\) for every relevant x and y. This might require \(10^{10^{10}}\) features, but that’s fine. The point is that overparameterized linear regression on these features is for all practical purposes the same as the kernel method. Note again that the feature transform is independent of \(x_i\) and \(y_i\). Kernel methods hence inherit many properties and limitations of linear regression.&lt;/p&gt;

&lt;p&gt;Some applications:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Linear regression (in view of predictions) only depends on the inner products between data points. For example, duplicating a feature is the same as scaling it up by \(\sqrt 2\), and having features \(x_{i1}\) and \(x_{i2}\) is the same as having features \(\frac{x_{i1}+x_{i2}}{\sqrt 2}\) and \(\frac{x_{i1}-x_{i2}}{\sqrt 2}\).&lt;/li&gt;
  &lt;li&gt;In terms of notation, I often find it easier to work with (and think about) a single feature vector for each data point, than considering pair-wise similarities through a kernel function. I feel like it makes it easier to apply linear algebra notation and operations.&lt;/li&gt;
  &lt;li&gt;This one is a bit vague: Kernel methods can’t perform “feature learning”, i.e they are stuck with a fixed set of features (given by the kernel function) in a way where they can’t disproportionally focus on the important features. This is in contrast to for example neural networks which we hope will learn good representations of the data useful for transfer learning, ignoring irrelevant features, etc.&lt;/li&gt;
&lt;/ul&gt;</content><author><name></name></author><summary type="html">In this blog post I explain kernel methods and the intuition that they’re “basically just overparameterized linear regression”. First, I will explain what I mean by kernel methods and overparameterized linear regression. I will view them as two methods for solving the following problem:</summary></entry></feed>