Skip to content

Commit

Permalink
Deployed 692996e with MkDocs version: 1.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
tanyuqian committed Aug 21, 2024
1 parent e529356 commit 608138e
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 208 deletions.
295 changes: 188 additions & 107 deletions deployer/index.html

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions mnist/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@

<h1>MNIST Example</h1>

<p>This is a trivial MNIST example with RedCoast. Runnable by
<p>This is a trivial MNIST example with RedCoast (<code>pip install redco==0.4.22</code>). Runnable by
<div class="highlight"><pre><span></span><code>python main.py
</code></pre></div></p>
<p>To simulate multiple devices in cpu-only envs,
Expand Down Expand Up @@ -601,14 +601,14 @@ <h3 id="source-code">Source Code<a class="headerlink" href="#source-code" title=


<span class="c1"># Loss function converting model inputs to a scalar loss</span>
<span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">train_rng</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">is_training</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">is_training</span><span class="p">):</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">apply_fn</span><span class="p">({</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">params</span><span class="p">},</span> <span class="n">batch</span><span class="p">[</span><span class="s1">&#39;images&#39;</span><span class="p">])</span>
<span class="k">return</span> <span class="n">optax</span><span class="o">.</span><span class="n">softmax_cross_entropy_with_integer_labels</span><span class="p">(</span>
<span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">batch</span><span class="p">[</span><span class="s1">&#39;labels&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>


<span class="c1"># Predict function converting model inputs to the model outputs</span>
<span class="k">def</span> <span class="nf">pred_fn</span><span class="p">(</span><span class="n">pred_rng</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">model</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">pred_fn</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">model</span><span class="p">):</span>
<span class="n">accs</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">({</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">params</span><span class="p">},</span> <span class="n">batch</span><span class="p">[</span><span class="s1">&#39;images&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="p">{</span><span class="s1">&#39;acc&#39;</span><span class="p">:</span> <span class="n">accs</span><span class="p">}</span>

Expand Down
Binary file modified objects.inv
Binary file not shown.
36 changes: 6 additions & 30 deletions predictor/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -848,13 +848,7 @@ <h2 id="redco.predictors.predictor.Predictor" class="doc doc-heading">
<span class="normal">152</span>
<span class="normal">153</span>
<span class="normal">154</span>
<span class="normal">155</span>
<span class="normal">156</span>
<span class="normal">157</span>
<span class="normal">158</span>
<span class="normal">159</span>
<span class="normal">160</span>
<span class="normal">161</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">Predictor</span><span class="p">:</span>
<span class="normal">155</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">Predictor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Predictor class managing distributed inference process.</span>

<span class="sd"> Attributes:</span>
Expand Down Expand Up @@ -961,15 +955,9 @@ <h2 id="redco.predictors.predictor.Predictor" class="doc doc-heading">
<span class="bp">self</span><span class="o">.</span><span class="n">setup_running_step</span><span class="p">(</span>
<span class="n">dummy_batch</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">params_shape_or_params</span><span class="o">=</span><span class="n">params</span><span class="p">)</span>

<span class="n">pred_rng</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deployer</span><span class="o">.</span><span class="n">gen_rng</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mesh</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">pred_rng</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span>
<span class="n">pred_rng</span><span class="p">,</span> <span class="n">num</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">process_count</span><span class="p">())[</span><span class="n">jax</span><span class="o">.</span><span class="n">process_index</span><span class="p">()]</span>
<span class="n">pred_rng</span> <span class="o">=</span> <span class="n">shard_prng_key</span><span class="p">(</span><span class="n">pred_rng</span><span class="p">)</span>

<span class="n">rng</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deployer</span><span class="o">.</span><span class="n">gen_model_step_rng</span><span class="p">()</span>
<span class="n">batch_preds_with_idxes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deployer</span><span class="o">.</span><span class="n">run_model_step</span><span class="p">(</span>
<span class="n">step_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_p_pred_step</span><span class="p">,</span>
<span class="n">input_args</span><span class="o">=</span><span class="p">(</span><span class="n">pred_rng</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">))</span>
<span class="n">step_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_p_pred_step</span><span class="p">,</span> <span class="n">input_args</span><span class="o">=</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">))</span>
<span class="n">batch_preds</span> <span class="o">=</span> <span class="n">process_batch_preds</span><span class="p">(</span>
<span class="n">batch_preds_with_idxes</span><span class="o">=</span><span class="n">batch_preds_with_idxes</span><span class="p">,</span> <span class="n">mesh</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mesh</span><span class="p">)</span>
<span class="n">batch_preds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_output_fn</span><span class="p">(</span><span class="n">batch_preds</span><span class="p">)</span>
Expand Down Expand Up @@ -1383,13 +1371,7 @@ <h3 id="redco.predictors.predictor.Predictor.predict" class="doc doc-heading">
<span class="normal">147</span>
<span class="normal">148</span>
<span class="normal">149</span>
<span class="normal">150</span>
<span class="normal">151</span>
<span class="normal">152</span>
<span class="normal">153</span>
<span class="normal">154</span>
<span class="normal">155</span>
<span class="normal">156</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="normal">150</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">examples</span><span class="p">,</span>
<span class="n">per_device_batch_size</span><span class="p">,</span>
<span class="n">params</span><span class="p">,</span>
Expand Down Expand Up @@ -1440,15 +1422,9 @@ <h3 id="redco.predictors.predictor.Predictor.predict" class="doc doc-heading">
<span class="bp">self</span><span class="o">.</span><span class="n">setup_running_step</span><span class="p">(</span>
<span class="n">dummy_batch</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">params_shape_or_params</span><span class="o">=</span><span class="n">params</span><span class="p">)</span>

<span class="n">pred_rng</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deployer</span><span class="o">.</span><span class="n">gen_rng</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mesh</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">pred_rng</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span>
<span class="n">pred_rng</span><span class="p">,</span> <span class="n">num</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">process_count</span><span class="p">())[</span><span class="n">jax</span><span class="o">.</span><span class="n">process_index</span><span class="p">()]</span>
<span class="n">pred_rng</span> <span class="o">=</span> <span class="n">shard_prng_key</span><span class="p">(</span><span class="n">pred_rng</span><span class="p">)</span>

<span class="n">rng</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deployer</span><span class="o">.</span><span class="n">gen_model_step_rng</span><span class="p">()</span>
<span class="n">batch_preds_with_idxes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deployer</span><span class="o">.</span><span class="n">run_model_step</span><span class="p">(</span>
<span class="n">step_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_p_pred_step</span><span class="p">,</span>
<span class="n">input_args</span><span class="o">=</span><span class="p">(</span><span class="n">pred_rng</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">))</span>
<span class="n">step_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_p_pred_step</span><span class="p">,</span> <span class="n">input_args</span><span class="o">=</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">))</span>
<span class="n">batch_preds</span> <span class="o">=</span> <span class="n">process_batch_preds</span><span class="p">(</span>
<span class="n">batch_preds_with_idxes</span><span class="o">=</span><span class="n">batch_preds_with_idxes</span><span class="p">,</span> <span class="n">mesh</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mesh</span><span class="p">)</span>
<span class="n">batch_preds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_output_fn</span><span class="p">(</span><span class="n">batch_preds</span><span class="p">)</span>
Expand Down
2 changes: 1 addition & 1 deletion search/search_index.json

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions sitemap.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,27 @@
<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
<url>
<loc>https://tanyuqian.github.io/redco/</loc>
<lastmod>2024-08-20</lastmod>
<lastmod>2024-08-21</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>https://tanyuqian.github.io/redco/deployer/</loc>
<lastmod>2024-08-20</lastmod>
<lastmod>2024-08-21</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>https://tanyuqian.github.io/redco/mnist/</loc>
<lastmod>2024-08-20</lastmod>
<lastmod>2024-08-21</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>https://tanyuqian.github.io/redco/predictor/</loc>
<lastmod>2024-08-20</lastmod>
<lastmod>2024-08-21</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>https://tanyuqian.github.io/redco/trainer/</loc>
<lastmod>2024-08-20</lastmod>
<lastmod>2024-08-21</lastmod>
<changefreq>daily</changefreq>
</url>
</urlset>
Binary file modified sitemap.xml.gz
Binary file not shown.
Loading

0 comments on commit 608138e

Please sign in to comment.