<?xml version='1.0' encoding='UTF-8'?>
<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en">
  <id>https://blog.dask.org</id>
  <title>Dask Working Notes - Posts by Nicholas Sofroniew</title>
  <updated>2026-03-05T15:05:20.249196+00:00</updated>
  <link href="https://blog.dask.org"/>
  <link href="https://blog.dask.org/blog/author/nicholas-sofroniew/atom.xml" rel="self"/>
  <generator uri="https://ablog.readthedocs.io/" version="0.11.12">ABlog</generator>
  <entry>
    <id>https://blog.dask.org/2021/03/29/apply-pretrained-pytorch-model/</id>
    <title>Dask with PyTorch for large scale image analysis</title>
    <updated>2021-03-29T00:00:00+00:00</updated>
    <author>
      <name>Genevieve Buckley</name>
    </author>
    <content type="html">&lt;aside class="system-message"&gt;
&lt;p class="system-message-title"&gt;System Message: WARNING/2 (&lt;span class="docutils literal"&gt;/opt/build/repo/2021/03/29/apply-pretrained-pytorch-model.md&lt;/span&gt;, line 9)&lt;/p&gt;
&lt;p&gt;Document headings start at H2, not H1 [myst.header]&lt;/p&gt;
&lt;/aside&gt;
&lt;section id="executive-summary"&gt;

&lt;p&gt;This post explores applying a pre-trained &lt;a class="reference external" href="https://pytorch.org/"&gt;PyTorch&lt;/a&gt; model in parallel with Dask Array.&lt;/p&gt;
&lt;p&gt;We cover a simple example applying a pre-trained UNet to a stack of images to generate features for every pixel.&lt;/p&gt;
&lt;aside class="system-message"&gt;
&lt;p class="system-message-title"&gt;System Message: WARNING/2 (&lt;span class="docutils literal"&gt;/opt/build/repo/2021/03/29/apply-pretrained-pytorch-model.md&lt;/span&gt;, line 15)&lt;/p&gt;
&lt;p&gt;Document headings start at H2, not H1 [myst.header]&lt;/p&gt;
&lt;/aside&gt;
&lt;/section&gt;
&lt;section id="a-worked-example"&gt;
&lt;h1&gt;A Worked Example&lt;/h1&gt;
&lt;p&gt;Let’s start with an example applying a pre-trained &lt;a class="reference external" href="https://arxiv.org/abs/1505.04597"&gt;UNet&lt;/a&gt; to a stack of light sheet microscopy data.&lt;/p&gt;
&lt;p&gt;In this example, we:&lt;/p&gt;
&lt;ol class="arabic simple"&gt;
&lt;li&gt;&lt;p&gt;Load the image data from Zarr into a multi-chunked Dask array&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;Load a pre-trained PyTorch model that featurizes images&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;Construct a function to apply the model onto each chunk&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;Apply that function across the Dask array with the dask.array.map_blocks function.&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;Store the result back into Zarr format&lt;/p&gt;&lt;/li&gt;
&lt;/ol&gt;
&lt;section id="step-1-load-the-image-data"&gt;
&lt;h2&gt;Step 1. Load the image data&lt;/h2&gt;
&lt;p&gt;First, we load the image data into a Dask array.&lt;/p&gt;
&lt;p&gt;The example dataset we’re using here is lattice lightsheet microscopy of the tail region of a zebrafish embryo. It is described in &lt;a class="reference external" href="http://dx.doi.org/10.1126/science.aaq1392"&gt;this Science paper&lt;/a&gt; (see Figure 4), and provided with permission from Srigokul Upadhyayula.&lt;/p&gt;
&lt;blockquote&gt;
&lt;div&gt;&lt;p&gt;Liu &lt;em&gt;et al.&lt;/em&gt; 2018 “Observing the cell in its native state: Imaging subcellular dynamics in multicellular organisms” &lt;em&gt;Science&lt;/em&gt;, Vol. 360, Issue 6386, eaaq1392 DOI: 10.1126/science.aaq1392 (&lt;a class="reference external" href="http://dx.doi.org/10.1126/science.aaq1392"&gt;link&lt;/a&gt;)&lt;/p&gt;
&lt;/div&gt;&lt;/blockquote&gt;
&lt;p&gt;This is the same data that we analysed in our last &lt;a class="reference external" href="https://blog.dask.org/2019/08/09/image-itk"&gt;blogpost on Dask and ITK&lt;/a&gt;. You should note the similarities to that workflow even though we are now using new libaries and performing different analyses.&lt;/p&gt;
&lt;div class="highlight-default notranslate"&gt;&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="n"&gt;cd&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;/Users/nicholassofroniew/Github/image-demos/data/LLSM&amp;#39;&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;div class="highlight-python notranslate"&gt;&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# Load our data&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;dask.array&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="k"&gt;as&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;da&lt;/span&gt;
&lt;span class="n"&gt;imgs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;da&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_zarr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;AOLLSM_m4_560nm.zarr&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;imgs&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;div class="highlight-default notranslate"&gt;&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="n"&gt;dask&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="o"&gt;&amp;lt;&lt;/span&gt;&lt;span class="n"&gt;from&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;zarr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;199&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;768&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1024&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chunksize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;768&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1024&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/section&gt;
&lt;section id="step-2-load-a-pre-trained-pytorch-model"&gt;
&lt;h2&gt;Step 2. Load a pre-trained PyTorch model&lt;/h2&gt;
&lt;p&gt;Next, we load our pre-trained UNet model.&lt;/p&gt;
&lt;p&gt;This UNet model takes in an 2D image and returns a 2D x 16 array, where each pixel is now associate with a feature vector of length 16.&lt;/p&gt;
&lt;p&gt;We thank Mars Huang for training this particular UNet on a corpous of biological images to produce biologically relevant feature vectors, as part of his work on &lt;a class="reference external" href="https://github.com/transformify-plugins/segmentify"&gt;interactive bio-image segmentation&lt;/a&gt;. These features can then be used for more downstream image processing tasks such as image segmentation.&lt;/p&gt;
&lt;div class="highlight-python notranslate"&gt;&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# Load our pretrained UNet¶&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;segmentify.model&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;UNet&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;layers&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;load_unet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;span class="w"&gt;    &lt;/span&gt;&lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Load a pretrained UNet model.&amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;

    &lt;span class="c1"&gt;# load in saved model&lt;/span&gt;
    &lt;span class="n"&gt;pth&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;model_args&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pth&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;model_args&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;model_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pth&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;model_state&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;UNet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;model_args&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_state_dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# remove last layer and activation&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;segment&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;layers&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Identity&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;activate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;layers&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Identity&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;

&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;load_unet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;HPA_3.pth&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/section&gt;
&lt;section id="step-3-construct-a-function-to-apply-the-model-to-each-chunk"&gt;
&lt;h2&gt;Step 3. Construct a function to apply the model to each chunk&lt;/h2&gt;
&lt;p&gt;We make a function to apply our pre-trained UNet model to each chunk of the Dask array.&lt;/p&gt;
&lt;p&gt;Because Dask arrays are just made out of Numpy arrays which are easily converted to Torch arrays, we’re able to leverage the power of machine learning at scale.&lt;/p&gt;
&lt;div class="highlight-python notranslate"&gt;&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# Apply UNet featurization&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;numpy&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="k"&gt;as&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;np&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;unet_featurize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;span class="w"&gt;    &lt;/span&gt;&lt;span class="sd"&gt;&amp;quot;&amp;quot;&amp;quot;Featurize pixels in an image using pretrained UNet model.&lt;/span&gt;
&lt;span class="sd"&gt;    &amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;
    &lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;numpy&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="k"&gt;as&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;np&lt;/span&gt;
    &lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;torch&lt;/span&gt;

    &lt;span class="c1"&gt;# Extract the 2D image data from the Dask array&lt;/span&gt;
    &lt;span class="c1"&gt;# Original Dask array dimensions were (time, z-slice, y, x)&lt;/span&gt;
    &lt;span class="n"&gt;img&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="c1"&gt;# Put the data into a shape PyTorch expects&lt;/span&gt;
    &lt;span class="c1"&gt;# Expected dimensions are (Batch x Channel x Width x Height)&lt;/span&gt;
    &lt;span class="n"&gt;img&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;img&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="c1"&gt;# convert image to torch Tensor&lt;/span&gt;
    &lt;span class="n"&gt;img&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;img&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="c1"&gt;# pass image through model&lt;/span&gt;
    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="n"&gt;features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;img&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;numpy&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="c1"&gt;# generate feature vectors (w,h,f)&lt;/span&gt;
    &lt;span class="n"&gt;features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;features&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="c1"&gt;# Add back the leading length-one dimensions&lt;/span&gt;
    &lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;features&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;result&lt;/span&gt;

&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;Note: Very observant readers might notice that the steps for extracting the 2D image data and then putting it into a shape PyTorch expects appear to be redundant. It is redundant for our particular example, but that might easily not have been the case.&lt;/p&gt;
&lt;p&gt;To explain this in more detail, the UNet expects 4D input, with dimensions &lt;code class="docutils literal notranslate"&gt;&lt;span class="pre"&gt;(Batch&lt;/span&gt; &lt;span class="pre"&gt;x&lt;/span&gt; &lt;span class="pre"&gt;Channel&lt;/span&gt; &lt;span class="pre"&gt;x&lt;/span&gt; &lt;span class="pre"&gt;Width&lt;/span&gt; &lt;span class="pre"&gt;x&lt;/span&gt; &lt;span class="pre"&gt;Height)&lt;/span&gt;&lt;/code&gt;. The original Dask array dimensions were &lt;code class="docutils literal notranslate"&gt;&lt;span class="pre"&gt;(time,&lt;/span&gt; &lt;span class="pre"&gt;z-slice,&lt;/span&gt; &lt;span class="pre"&gt;y,&lt;/span&gt; &lt;span class="pre"&gt;x)&lt;/span&gt;&lt;/code&gt;. In our example it just so happens those match in a way that makes removing and then adding the leading dimensions redundant, but depending on the shape of the original Dask array this might not have been true.&lt;/p&gt;
&lt;/section&gt;
&lt;section id="step-4-apply-that-function-across-the-dask-array"&gt;
&lt;h2&gt;Step 4. Apply that function across the Dask array&lt;/h2&gt;
&lt;p&gt;Now we apply that function to the data in our Dask array using &lt;a class="reference external" href="https://docs.dask.org/en/latest/array-api.html?highlight=map_blocks#dask.array.map_blocks"&gt;&lt;code class="docutils literal notranslate"&gt;&lt;span class="pre"&gt;dask.array.map_blocks&lt;/span&gt;&lt;/code&gt;&lt;/a&gt;.&lt;/p&gt;
&lt;div class="highlight-python notranslate"&gt;&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# Apply UNet featurization&lt;/span&gt;
&lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;da&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;map_blocks&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;unet_featurize&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;imgs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chunks&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;imgs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;imgs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;new_axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;out&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;div class="highlight-default notranslate"&gt;&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="n"&gt;dask&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="o"&gt;&amp;lt;&lt;/span&gt;&lt;span class="n"&gt;unet_featurize&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;199&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;768&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1024&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chunksize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;768&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1024&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/section&gt;
&lt;section id="step-5-store-the-result-back-into-zarr-format"&gt;
&lt;h2&gt;Step 5. Store the result back into Zarr format&lt;/h2&gt;
&lt;p&gt;Last, we store the result from the UNet model featurization as a zarr array.&lt;/p&gt;
&lt;div class="highlight-python notranslate"&gt;&lt;div class="highlight"&gt;&lt;pre&gt;&lt;span&gt;&lt;/span&gt;&lt;span class="c1"&gt;# Trigger computation and store&lt;/span&gt;
&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_zarr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;AOLLSM_featurized.zarr&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;overwrite&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;Now we’ve saved our output, these features can be used for more downstream image processing tasks such as image segmentation.&lt;/p&gt;
&lt;/section&gt;
&lt;section id="summing-up"&gt;
&lt;h2&gt;Summing up&lt;/h2&gt;
&lt;p&gt;Here we’ve shown how to apply a pre-trained PyTorch model to a Dask array of image data.&lt;/p&gt;
&lt;p&gt;Because our Dask array chunks are Numpy arrays, they can be easily converted to Torch arrays. This way, we’re able to leverage the power of machine learning at scale.&lt;/p&gt;
&lt;p&gt;This workflow was very similar to &lt;a class="reference external" href="https://blog.dask.org/2019/08/09/image-itk"&gt;our example&lt;/a&gt; using the dask.array.map_blocks function with ITK to perform image deconvolution. This shows you can easily adapt the same type of workflow to achieve many different types of analysis with Dask.&lt;/p&gt;
&lt;/section&gt;
&lt;/section&gt;
</content>
    <link href="https://blog.dask.org/2021/03/29/apply-pretrained-pytorch-model/"/>
    <summary>Document headings start at H2, not H1 [myst.header]</summary>
    <category term="PyTorch" label="PyTorch"/>
    <category term="deeplearning" label="deep learning"/>
    <category term="imaging" label="imaging"/>
    <published>2021-03-29T00:00:00+00:00</published>
  </entry>
</feed>
