Censoring and truncation are both missing-data problems, but the two are easy to confuse, so let's go through their definitions with figures and equations. If you run an ordinary linear regression on data like this, it is likely to underestimate the slope, which should be apparent just from looking at the plots. For such data, truncated regression or censored regression (also known as Tobit regression) is therefore normally used to obtain regression coefficients that are free of this bias. </span><a href="https://www.jstage.jst.go.jp/article/ojjams/24/1/24_1_129/_pdf" style="" target="_blank" rel="nofollow noopener noreferrer">https://www.jstage.jst.go.jp/article/ojjams/24/1/24_1_129/_pdf</a></p>
<a class="header-anchor-link" href="#%E5%88%87%E6%96%AD%EF%BC%88truncated%EF%BC%89" aria-hidden="true"/> Truncation (truncated)</h3>
<p>With truncation, observations whose target variable Y falls outside some range are simply never seen: they are missing from the dataset entirely. Let's check this with code.</p>
<div class="code-block-container"><pre class="language-python"><code class="language-python"><span class="token comment"># data outside bounds is never observed</span>
<span class="token keyword">def</span> <span class="token function">truncate_y</span><span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span><span class="token punctuation">:</span>
    keep <span class="token operator">=</span> <span class="token punctuation">(</span>y <span class="token operator">>=</span> bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">&</span> <span class="token punctuation">(</span>y <span class="token operator"><=</span> bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> <span class="token punctuation">(</span>x<span class="token punctuation">[</span>keep<span class="token punctuation">]</span><span class="token punctuation">,</span> y<span class="token punctuation">[</span>keep<span class="token punctuation">]</span><span class="token punctuation">)</span>

<span class="token comment"># demo data (assumes rng, a NumPy random Generator, and plt were set up earlier)</span>
slope<span class="token punctuation">,</span> intercept<span class="token punctuation">,</span> σ<span class="token punctuation">,</span> N <span class="token operator">=</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">200</span>
x <span class="token operator">=</span> rng<span class="token punctuation">.</span>uniform<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">10</span><span class="token punctuation">,</span> <span class="token number">10</span><span class="token punctuation">,</span> N<span class="token punctuation">)</span>
y <span class="token operator">=</span> rng<span class="token punctuation">.</span>normal<span class="token punctuation">(</span>loc<span class="token operator">=</span>slope <span class="token operator">*</span> x <span class="token operator">+</span> intercept<span class="token punctuation">,</span> scale<span class="token operator">=</span>σ<span class="token punctuation">)</span>
bounds <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">]</span>
xt<span class="token punctuation">,</span> yt <span class="token operator">=</span> truncate_y<span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span>

<span class="token comment"># plot</span>
plt<span class="token punctuation">.</span>plot<span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> <span class="token string">"."</span><span class="token punctuation">,</span> c<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">0.7</span><span class="token punctuation">,</span> <span class="token number">0.7</span><span class="token punctuation">,</span> <span class="token number">0.7</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>axhline<span class="token punctuation">(</span>bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> c<span class="token operator">=</span><span class="token string">"r"</span><span class="token punctuation">,</span> ls<span class="token operator">=</span><span class="token string">"--"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>axhline<span class="token punctuation">(</span>bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> c<span class="token operator">=</span><span class="token string">"r"</span><span class="token punctuation">,</span> ls<span class="token operator">=</span><span class="token string">"--"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>xlabel<span class="token punctuation">(</span><span class="token string">"x"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>ylabel<span class="token punctuation">(</span><span class="token string">"y"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>plot<span class="token punctuation">(</span>xt<span class="token punctuation">,</span> yt<span class="token punctuation">,</span> <span class="token string">"."</span><span class="token punctuation">,</span> c<span class="token operator">=</span><span 
class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span></code></pre></div>
<p>Expressed as an equation, the code above becomes the following, where y* is the latent value before truncation (the Gaussian noise term added by <code>rng.normal</code> is omitted) and y is not observed at all whenever y* falls outside the bounds.</p>
y^* = x\beta + \alpha \\ y = y^* \quad (5 \ge y^* \ge -5)
<a class="header-anchor-link" href="#%E6%89%93%E3%81%A1%E5%88%87%E3%82%8A%EF%BC%88censored%EF%BC%89" aria-hidden="true"/> Censoring (censored)</h3>
<p>With censoring, when the target variable Y falls outside some range, the observation is not discarded; instead, the boundary value is recorded as the measurement. Again, let's check this with code.</p>
<div class="code-block-container"><pre class="language-python"><code class="language-python"><span class="token comment"># values outside bounds are replaced by the boundary value</span>
<span class="token keyword">def</span> <span class="token function">censor_y</span><span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span><span class="token punctuation">:</span>
    cy <span class="token operator">=</span> copy<span class="token punctuation">(</span>y<span class="token punctuation">)</span>
    cy<span class="token punctuation">[</span>y <span class="token operator"><=</span> bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">]</span> <span class="token operator">=</span> bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span>
    cy<span class="token punctuation">[</span>y <span class="token operator">>=</span> bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">]</span> <span class="token operator">=</span> bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span>
    <span class="token keyword">return</span> <span class="token punctuation">(</span>x<span class="token punctuation">,</span> cy<span class="token punctuation">)</span>

<span class="token comment"># demo data</span>
slope<span class="token punctuation">,</span> intercept<span class="token punctuation">,</span> σ<span class="token punctuation">,</span> N <span class="token operator">=</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">200</span>
x <span class="token operator">=</span> rng<span class="token punctuation">.</span>uniform<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">10</span><span class="token punctuation">,</span> <span class="token number">10</span><span class="token punctuation">,</span> N<span class="token punctuation">)</span>
y <span class="token operator">=</span> rng<span class="token punctuation">.</span>normal<span class="token punctuation">(</span>loc<span class="token operator">=</span>slope <span class="token operator">*</span> x <span class="token operator">+</span> intercept<span class="token punctuation">,</span> scale<span class="token operator">=</span>σ<span class="token punctuation">)</span>
bounds <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">]</span>
xc<span class="token punctuation">,</span> yc <span class="token operator">=</span> censor_y<span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span>

<span class="token comment"># plot</span>
plt<span class="token punctuation">.</span>plot<span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> <span class="token string">"."</span><span class="token punctuation">,</span> c<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">0.7</span><span class="token punctuation">,</span> <span class="token number">0.7</span><span class="token punctuation">,</span> <span class="token number">0.7</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>axhline<span class="token punctuation">(</span>bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> c<span class="token operator">=</span><span class="token string">"r"</span><span class="token punctuation">,</span> ls<span class="token operator">=</span><span class="token string">"--"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>axhline<span class="token punctuation">(</span>bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> c<span class="token operator">=</span><span class="token string">"r"</span><span class="token punctuation">,</span> ls<span class="token operator">=</span><span class="token string">"--"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>xlabel<span class="token punctuation">(</span><span class="token string">"x"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>ylabel<span class="token punctuation">(</span><span class="token string">"y"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>plot<span class="token punctuation">(</span>xc<span 
class="token punctuation">,</span> yc<span class="token punctuation">,</span> <span class="token string">"."</span><span class="token punctuation">,</span> c<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span></code></pre></div>
<p>Expressed as an equation, the code above becomes the following, where y* is the latent value before censoring (the Gaussian noise term added by <code>rng.normal</code> is omitted).</p>
y^* = x\beta + \alpha \\ y = \left\{ \begin{array}{ll} 5 & (y^* \ge 5) \\ y^* & (5 > y^* > -5) \\ -5 & (-5 \ge y^*) \end{array} \right.
<a class="header-anchor-link" href="#pymc%E3%81%A7%E3%81%AE%E5%88%86%E5%B8%83" aria-hidden="true"/> Distributions in PyMC</h2>
<a class="header-anchor-link" href="#%E5%88%87%E6%96%AD%EF%BC%88truncated%EF%BC%89%E5%88%86%E5%B8%83" aria-hidden="true"/> Truncated distribution</h3>
<p>Just wrap a distribution with <code>pm.Truncated()</code>. The resulting samples are truncated: no values below the lower bound of -1 are observed.</p>
<div class="code-block-container"><pre class="language-python"><code class="language-python">d <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">.</span>dist<span class="token punctuation">(</span>mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
truncated_d <span class="token operator">=</span> pm<span class="token punctuation">.</span>Truncated<span class="token punctuation">.</span>dist<span class="token punctuation">(</span>d<span class="token punctuation">,</span> lower<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> upper<span class="token operator">=</span><span class="token boolean">None</span><span class="token 
punctuation">)</span>
samples <span class="token operator">=</span> pm<span class="token punctuation">.</span>draw<span class="token punctuation">(</span>d<span class="token punctuation">,</span> draws<span class="token operator">=</span><span class="token number">1000</span><span class="token punctuation">)</span>
truncated_samples <span class="token operator">=</span> pm<span class="token punctuation">.</span>draw<span class="token punctuation">(</span>truncated_d<span class="token punctuation">,</span> draws<span class="token operator">=</span><span class="token number">1000</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>hist<span class="token punctuation">(</span>samples<span class="token punctuation">,</span> density<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> alpha<span class="token operator">=</span><span class="token number">0.5</span><span class="token punctuation">,</span> label<span class="token operator">=</span><span class="token string">"normal"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>hist<span class="token punctuation">(</span>truncated_samples<span class="token punctuation">,</span> density<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> alpha<span class="token operator">=</span><span class="token number">0.5</span><span class="token punctuation">,</span> label<span class="token operator">=</span><span class="token string">"truncated normal"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>legend<span class="token punctuation">(</span><span class="token punctuation">)</span></code></pre></div>
<a class="header-anchor-link" href="#%E6%89%93%E3%81%A1%E5%88%87%E3%82%8A%EF%BC%88censored%EF%BC%89%E5%88%86%E5%B8%83" aria-hidden="true"/> Censored distribution</h3>
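<p>As a library-free point of comparison for the two distributions in this section, the following is a minimal sketch using only Python's standard library. The helper names <code>truncate_lower</code> and <code>censor_lower</code>, the seed, and the sample size are illustrative assumptions and are not part of PyMC. It draws from a standard normal, applies a lower bound of -1 in both ways, and checks that truncation drops samples while censoring keeps every sample and piles the out-of-range ones up at the bound.</p>

```python
import random
from statistics import NormalDist


def truncate_lower(samples, lower):
    """Truncation: values below `lower` are dropped and never observed."""
    return [s for s in samples if s >= lower]


def censor_lower(samples, lower):
    """Censoring: values below `lower` are recorded as `lower` itself."""
    return [max(s, lower) for s in samples]


rng = random.Random(0)  # hypothetical seed, chosen only for reproducibility
samples = [rng.gauss(0.0, 1.0) for _ in range(10_000)]
lower = -1.0

truncated = truncate_lower(samples, lower)
censored = censor_lower(samples, lower)

# Share of censored samples that sit exactly on the bound (the point mass),
# compared with the mass a standard normal places below the bound.
share_at_bound = sum(1 for s in censored if s == lower) / len(censored)
expected_share = NormalDist(0.0, 1.0).cdf(lower)

print("sizes (original, truncated, censored):",
      len(samples), len(truncated), len(censored))
print("share at bound vs. normal CDF at bound:",
      round(share_at_bound, 3), round(expected_share, 3))
```

<p>The truncated sample is smaller than the original, whereas the censored sample has the same size, and the share of censored values sitting exactly at -1 should be close to the standard normal CDF at -1 (about 0.159).</p>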
<p>Just wrap a distribution with <code>pm.Censored()</code>. Every value below the lower bound of -1 is observed as exactly -1.</p>
<div class="code-block-container"><pre class="language-python"><code class="language-python">d <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">.</span>dist<span class="token punctuation">(</span>mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
censored_d <span class="token operator">=</span> pm<span class="token punctuation">.</span>Censored<span class="token punctuation">.</span>dist<span class="token punctuation">(</span>d<span class="token punctuation">,</span> lower<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> upper<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">)</span>
samples <span class="token operator">=</span> pm<span class="token punctuation">.</span>draw<span class="token punctuation">(</span>d<span class="token punctuation">,</span> draws<span class="token operator">=</span><span class="token number">1000</span><span class="token punctuation">)</span>
<span class="token comment"># draw from the censored distribution (not truncated_d)</span>
censored_samples <span class="token operator">=</span> pm<span class="token punctuation">.</span>draw<span class="token punctuation">(</span>censored_d<span class="token punctuation">,</span> draws<span class="token operator">=</span><span class="token number">1000</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>hist<span class="token punctuation">(</span>samples<span class="token punctuation">,</span> density<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> alpha<span class="token operator">=</span><span 
class="token number">0.5</span><span class="token punctuation">,</span> label<span class="token operator">=</span><span class="token string">"normal"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>hist<span class="token punctuation">(</span>censored_samples<span class="token punctuation">,</span> density<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> alpha<span class="token operator">=</span><span class="token number">0.5</span><span class="token punctuation">,</span> label<span class="token operator">=</span><span class="token string">"censored normal"</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>legend<span class="token punctuation">(</span><span class="token punctuation">)</span></code></pre></div>
<p>We try the following three approaches on data like the above, after first fitting an ordinary linear regression to the truncated data as a baseline. The third approach exploits the fact that in censored data only the value of the target variable is unobserved: it treats the censored Y values as missing and models the censored data as a missing-value problem. The advantage of this approach is that the true values can be estimated from the posterior distribution even for observations whose actual value was never recorded.</p>
<li>Truncated regression</li>
<li>Censored regression</li>
<li>Imputed censored regression</li>
<a class="header-anchor-link" href="#%E9%80%9A%E5%B8%B8%E3%81%AE%E7%B7%9A%E5%BD%A2%E5%9B%9E%E5%B8%B0" aria-hidden="true"/> Ordinary linear regression</h3>
<div class="code-block-container"><pre class="language-python"><code class="language-python"><span class="token keyword">def</span> <span class="token function">normal_regression</span><span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">with</span> pm<span class="token punctuation">.</span>Model<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">as</span> model<span class="token punctuation">:</span>
        <span class="token comment"># define prior</span>
        slope <span class="token operator">=</span> pm<span 
class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"slope"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        intercept <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"intercept"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        σ <span class="token operator">=</span> pm<span class="token punctuation">.</span>HalfNormal<span class="token punctuation">(</span><span class="token string">"σ"</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"obs"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span>slope <span class="token operator">*</span> x <span class="token operator">+</span> intercept<span class="token punctuation">,</span> sigma<span class="token operator">=</span>σ<span class="token punctuation">,</span> observed<span class="token operator">=</span>y<span class="token punctuation">)</span>
    <span class="token keyword">return</span> model

normal_model <span class="token operator">=</span> normal_regression<span class="token punctuation">(</span>xt<span class="token punctuation">,</span> yt<span class="token punctuation">)</span>
<span class="token keyword">with</span> normal_model<span 
class="token punctuation">:</span>
    normal_idata <span class="token operator">=</span> pm<span class="token punctuation">.</span>sample<span class="token punctuation">(</span><span class="token number">3000</span><span class="token punctuation">,</span> tune<span class="token operator">=</span><span class="token number">1000</span><span class="token punctuation">,</span> chains<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> idata_kwargs<span class="token operator">=</span><span class="token punctuation">{</span><span class="token string">"log_likelihood"</span><span class="token punctuation">:</span> <span class="token boolean">True</span><span class="token punctuation">}</span><span class="token punctuation">,</span> random_seed<span class="token operator">=</span><span class="token number">42</span><span class="token punctuation">,</span> target_accept<span class="token operator">=</span><span class="token number">0.90</span><span class="token punctuation">,</span> return_inferencedata<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
az<span class="token punctuation">.</span>summary<span class="token punctuation">(</span>normal_idata<span class="token punctuation">,</span> round_to<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span></code></pre></div>
<a class="header-anchor-link" href="#%E5%88%87%E6%96%AD%E5%9B%9E%E5%B8%B0%EF%BC%88truncated-regression%EF%BC%89" aria-hidden="true"/> Truncated regression</h3>
<div class="code-block-container"><pre class="language-python"><code class="language-python"><span class="token keyword">def</span> <span class="token function">truncated_regression</span><span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span><span class="token 
punctuation">:</span>
    <span class="token keyword">with</span> pm<span class="token punctuation">.</span>Model<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">as</span> model<span class="token punctuation">:</span>
        <span class="token comment"># define prior</span>
        slope <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"slope"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        intercept <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"intercept"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        σ <span class="token operator">=</span> pm<span class="token punctuation">.</span>HalfNormal<span class="token punctuation">(</span><span class="token string">"σ"</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        normal_dist <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">.</span>dist<span class="token punctuation">(</span>mu<span class="token operator">=</span>slope <span class="token operator">*</span> x <span class="token operator">+</span> intercept<span class="token punctuation">,</span> sigma<span class="token operator">=</span>σ<span class="token punctuation">)</span>
        pm<span class="token punctuation">.</span>Truncated<span class="token punctuation">(</span><span class="token string">"obs"</span><span class="token punctuation">,</span> normal_dist<span class="token punctuation">,</span> lower<span class="token operator">=</span>bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> upper<span class="token operator">=</span>bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> observed<span class="token operator">=</span>y<span class="token punctuation">)</span>
    <span class="token keyword">return</span> model

truncated_model <span class="token operator">=</span> truncated_regression<span class="token punctuation">(</span>xt<span class="token punctuation">,</span> yt<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span>
<span class="token keyword">with</span> truncated_model<span class="token punctuation">:</span>
    truncated_idata <span class="token operator">=</span> pm<span class="token punctuation">.</span>sample<span class="token punctuation">(</span><span class="token number">3000</span><span class="token punctuation">,</span> tune<span class="token operator">=</span><span class="token number">1000</span><span class="token punctuation">,</span> chains<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> idata_kwargs<span class="token operator">=</span><span class="token punctuation">{</span><span class="token string">"log_likelihood"</span><span class="token punctuation">:</span> <span class="token boolean">True</span><span class="token punctuation">}</span><span class="token punctuation">,</span> random_seed<span class="token operator">=</span><span class="token number">42</span><span class="token punctuation">,</span> 
        target_accept<span class="token operator">=</span><span class="token number">0.90</span><span class="token punctuation">,</span> return_inferencedata<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
az<span class="token punctuation">.</span>summary<span class="token punctuation">(</span>truncated_idata<span class="token punctuation">,</span> round_to<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span></code></pre></div>
<a class="header-anchor-link" href="#%E6%89%93%E3%81%A1%E5%88%87%E3%82%8A%E5%9B%9E%E5%B8%B0%EF%BC%88censored-regression%EF%BC%89" aria-hidden="true"/> Censored regression</h3>
<div class="code-block-container"><pre class="language-python"><code class="language-python"><span class="token keyword">def</span> <span class="token function">censored_regression</span><span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">with</span> pm<span class="token punctuation">.</span>Model<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">as</span> model<span class="token punctuation">:</span>
        <span class="token comment"># define prior</span>
        slope <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"slope"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        intercept <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token 
punctuation">(</span><span class="token string">"intercept"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        σ <span class="token operator">=</span> pm<span class="token punctuation">.</span>HalfNormal<span class="token punctuation">(</span><span class="token string">"σ"</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        y_latent <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">.</span>dist<span class="token punctuation">(</span>mu<span class="token operator">=</span>slope <span class="token operator">*</span> x <span class="token operator">+</span> intercept<span class="token punctuation">,</span> sigma<span class="token operator">=</span>σ<span class="token punctuation">)</span>
        obs <span class="token operator">=</span> pm<span class="token punctuation">.</span>Censored<span class="token punctuation">(</span><span class="token string">"obs"</span><span class="token punctuation">,</span> y_latent<span class="token punctuation">,</span> lower<span class="token operator">=</span>bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> upper<span class="token operator">=</span>bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> observed<span class="token operator">=</span>y<span class="token punctuation">)</span>
    <span class="token keyword">return</span> model

censored_model <span class="token operator">=</span> censored_regression<span class="token 
punctuation">(</span>xc<span class="token punctuation">,</span> yc<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span>
<span class="token keyword">with</span> censored_model<span class="token punctuation">:</span>
    censored_idata <span class="token operator">=</span> pm<span class="token punctuation">.</span>sample<span class="token punctuation">(</span><span class="token number">3000</span><span class="token punctuation">,</span> tune<span class="token operator">=</span><span class="token number">1000</span><span class="token punctuation">,</span> chains<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> idata_kwargs<span class="token operator">=</span><span class="token punctuation">{</span><span class="token string">"log_likelihood"</span><span class="token punctuation">:</span> <span class="token boolean">True</span><span class="token punctuation">}</span><span class="token punctuation">,</span> random_seed<span class="token operator">=</span><span class="token number">42</span><span class="token punctuation">,</span> target_accept<span class="token operator">=</span><span class="token number">0.90</span><span class="token punctuation">,</span> return_inferencedata<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
az<span class="token punctuation">.</span>summary<span class="token punctuation">(</span>censored_idata<span class="token punctuation">,</span> round_to<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span></code></pre></div>
<a class="header-anchor-link" href="#%E6%AC%A0%E6%90%8D%E5%80%A4%E3%82%92%E5%9F%8B%E3%81%BE%E3%81%9F%E6%89%93%E3%81%A1%E5%88%87%E3%82%8A%E5%9B%9E%E5%B8%B0%EF%BC%88imputed-censored-regression%EF%BC%89" aria-hidden="true"/> Imputed censored regression</h3>
<div class="code-block-container"><pre 
class="language-python"><code class="language-python"><span class="token comment"># Number of right-censored observations</span>
n_right_censored <span class="token operator">=</span> <span class="token builtin">sum</span><span class="token punctuation">(</span>yc <span class="token operator">>=</span> bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
<span class="token comment"># Number of left-censored observations</span>
n_left_censored <span class="token operator">=</span> <span class="token builtin">sum</span><span class="token punctuation">(</span>yc <span class="token operator"><=</span> bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
<span class="token comment"># Number of observations whose actual value was recorded</span>
n_observed <span class="token operator">=</span> <span class="token builtin">len</span><span class="token punctuation">(</span>yc<span class="token punctuation">)</span> <span class="token operator">-</span> n_right_censored <span class="token operator">-</span> n_left_censored
<span class="token comment"># Censored data (predictor values only)</span>
xc_right_censored <span class="token operator">=</span> xc<span class="token punctuation">[</span>yc <span class="token operator">>=</span> bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">]</span>
xc_left_censored <span class="token operator">=</span> xc<span class="token punctuation">[</span>yc <span class="token operator"><=</span> bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">]</span>
<span class="token comment"># Uncensored data</span>
xc_uncensored <span class="token operator">=</span> xc<span class="token punctuation">[</span><span class="token punctuation">(</span>yc <span class="token operator">></span> bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">&</span> <span class="token punctuation">(</span>yc <span class="token operator"><</span> bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">]</span>
yc_uncensored <span class="token operator">=</span> yc<span class="token punctuation">[</span><span class="token punctuation">(</span>yc <span class="token operator">></span> bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">&</span> <span class="token punctuation">(</span>yc <span class="token operator"><</span> bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">]</span>
</code></pre></div><div class="code-block-container"><pre class="language-python"><code class="language-python"><span class="token keyword">def</span> <span class="token function">imputed_censored_regression</span><span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token comment"># Note: the model reads the arrays split above (xc_*, yc_uncensored), not the x and y arguments</span>
    low <span class="token operator">=</span> bounds<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span>
    high <span class="token operator">=</span> bounds<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span>
    <span class="token keyword">with</span> pm<span class="token punctuation">.</span>Model<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">as</span> model<span class="token punctuation">:</span>
        <span class="token comment"># define prior</span>
        slope <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"slope"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        intercept <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"intercept"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        σ <span class="token operator">=</span> pm<span class="token punctuation">.</span>HalfNormal<span class="token punctuation">(</span><span class="token string">"σ"</span><span class="token punctuation">,</span> sigma<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        <span class="token comment"># Distribution of the right-censored values</span>
        <span class="token comment"># One latent value per right-censored observation</span>
        <span class="token comment"># Set the starting value with initval</span>
        right_censored <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span>
            <span class="token string">"right_censored"</span><span class="token punctuation">,</span>
            slope <span class="token operator">*</span> xc_right_censored <span class="token operator">+</span> intercept<span class="token punctuation">,</span>
            sigma<span class="token operator">=</span>σ<span class="token punctuation">,</span>
            transform<span class="token operator">=</span>pm<span class="token punctuation">.</span>distributions<span class="token punctuation">.</span>transforms<span class="token punctuation">.</span>Interval<span class="token punctuation">(</span>high<span class="token punctuation">,</span> <span class="token boolean">None</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            shape<span class="token operator">=</span><span class="token builtin">int</span><span class="token punctuation">(</span>n_right_censored<span class="token punctuation">)</span><span class="token punctuation">,</span>
            initval<span class="token operator">=</span>np<span class="token punctuation">.</span>full<span class="token punctuation">(</span>n_right_censored<span class="token punctuation">,</span> high <span class="token operator">+</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
        <span class="token punctuation">)</span>
        <span class="token comment"># Distribution of the left-censored values</span>
        <span class="token comment"># One latent value per left-censored observation</span>
        <span class="token comment"># Set the starting value with initval</span>
        left_censored <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span>
            <span class="token string">"left_censored"</span><span class="token punctuation">,</span>
            slope <span class="token operator">*</span> xc_left_censored <span class="token operator">+</span> intercept<span class="token punctuation">,</span>
            sigma<span class="token operator">=</span>σ<span class="token punctuation">,</span>
            transform<span class="token operator">=</span>pm<span class="token punctuation">.</span>distributions<span class="token punctuation">.</span>transforms<span class="token punctuation">.</span>Interval<span class="token punctuation">(</span><span class="token boolean">None</span><span class="token punctuation">,</span> low<span class="token punctuation">)</span><span class="token punctuation">,</span>
            shape<span class="token operator">=</span><span class="token builtin">int</span><span class="token punctuation">(</span>n_left_censored<span class="token punctuation">)</span><span class="token punctuation">,</span>
            initval<span class="token operator">=</span>np<span class="token punctuation">.</span>full<span class="token punctuation">(</span>n_left_censored<span class="token punctuation">,</span> low <span class="token operator">-</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
        <span class="token punctuation">)</span>
        obs <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"obs"</span><span class="token punctuation">,</span> mu<span class="token operator">=</span>slope <span class="token operator">*</span> xc_uncensored <span class="token operator">+</span> intercept<span class="token punctuation">,</span> sigma<span class="token operator">=</span>σ<span class="token punctuation">,</span> observed<span class="token operator">=</span>yc_uncensored<span class="token punctuation">,</span> shape<span class="token operator">=</span><span class="token builtin">int</span><span class="token punctuation">(</span>n_observed<span class="token punctuation">)</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> model

imputed_censored_model <span class="token operator">=</span> imputed_censored_regression<span class="token punctuation">(</span>xc<span class="token punctuation">,</span> yc<span class="token punctuation">,</span> bounds<span class="token punctuation">)</span>
<span class="token keyword">with</span> imputed_censored_model<span class="token punctuation">:</span>
    imputed_censored_idata <span class="token operator">=</span> pm<span class="token punctuation">.</span>sample<span class="token punctuation">(</span><span class="token number">3000</span><span class="token punctuation">,</span> tune<span class="token operator">=</span><span class="token 
number">1000</span><span class="token punctuation">,</span> chains<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> idata_kwargs<span class="token operator">=</span><span class="token punctuation">{</span><span class="token string">"log_likelihood"</span><span class="token punctuation">:</span> <span class="token boolean">True</span><span class="token punctuation">}</span><span class="token punctuation">,</span> random_seed<span class="token operator">=</span><span class="token number">42</span><span class="token punctuation">,</span> target_accept<span class="token operator">=</span><span class="token number">0.90</span><span class="token punctuation">,</span> return_inferencedata<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
az<span class="token punctuation">.</span>summary<span class="token punctuation">(</span>imputed_censored_idata<span class="token punctuation">,</span> round_to<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span></code></pre></div>
<p><strong>Transforming a distribution</strong></p>
<p>Specifying <code>transform</code> when defining a distribution lets you transform it. The example below imposes a lower bound of -1 on a standard normal distribution.</p>
<div class="code-block-container"><pre class="language-python"><code class="language-python"><span class="token keyword">with</span> pm<span class="token punctuation">.</span>Model<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    interval <span class="token operator">=</span> pm<span class="token punctuation">.</span>distributions<span class="token punctuation">.</span>transforms<span class="token punctuation">.</span>Interval<span class="token punctuation">(</span>lower<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> upper<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">)</span>
    x <span class="token operator">=</span> pm<span class="token punctuation">.</span>Normal<span class="token punctuation">(</span><span class="token string">"x"</span><span class="token punctuation">,</span> transform<span class="token operator">=</span>interval<span class="token punctuation">)</span>
    samples <span class="token operator">=</span> pm<span class="token punctuation">.</span>sample<span class="token punctuation">(</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>hist<span class="token punctuation">(</span>samples<span class="token punctuation">.</span>posterior<span class="token punctuation">.</span>x<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span></code></pre></div>
<p><a href="https://www.pymc.io/projects/docs/en/stable/api/distributions/transforms.html" target="_blank" rel="nofollow noopener noreferrer">https://www.pymc.io/projects/docs/en/stable/api/distributions/transforms.html</a></p>
<h2><a class="header-anchor-link" href="#%E7%B5%90%E6%9E%9C%E3%81%AE%E6%AF%94%E8%BC%83" aria-hidden="true"/> Comparing the results</h2>
<p>The true slope is 1, and ordinary linear regression clearly underestimates it. The other three approaches all recover values close to the true slope. In this example, the censored-data approaches also land closer to the true value than truncated regression does. This will vary from case to case, but a likely reason is that truncated data lose both X and Y for the out-of-range points, whereas censored data lose only the value of the outcome Y, so more information remains for estimation.</p>
<div class="code-block-container"><pre class="language-python"><code class="language-python">fig<span class="token punctuation">,</span> ax <span class="token operator">=</span> plt<span class="token punctuation">.</span>subplots<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> figsize<span class="token operator">=</span><span class="token punctuation">(</span><span class="token 
number">10</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">)</span><span class="token punctuation">,</span> sharex<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
az<span class="token punctuation">.</span>plot_posterior<span class="token punctuation">(</span>normal_idata<span class="token punctuation">,</span> var_names<span class="token operator">=</span><span class="token punctuation">[</span><span class="token string">"slope"</span><span class="token punctuation">]</span><span class="token punctuation">,</span> ref_val<span class="token operator">=</span>slope<span class="token punctuation">,</span> ax<span class="token operator">=</span>ax<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
ax<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">.</span><span class="token builtin">set</span><span class="token punctuation">(</span>title<span class="token operator">=</span><span class="token string">"normal regression\n(truncated data)"</span><span class="token punctuation">,</span> xlabel<span class="token operator">=</span><span class="token string">"slope"</span><span class="token punctuation">)</span>
az<span class="token punctuation">.</span>plot_posterior<span class="token punctuation">(</span>truncated_idata<span class="token punctuation">,</span> var_names<span class="token operator">=</span><span class="token punctuation">[</span><span class="token string">"slope"</span><span class="token punctuation">]</span><span class="token punctuation">,</span> ref_val<span class="token operator">=</span>slope<span class="token punctuation">,</span> ax<span class="token operator">=</span>ax<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
ax<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">.</span><span class="token builtin">set</span><span class="token punctuation">(</span>title<span class="token operator">=</span><span class="token string">"Truncated regression\n(truncated data)"</span><span class="token punctuation">,</span> xlabel<span class="token operator">=</span><span class="token string">"slope"</span><span class="token punctuation">)</span>
az<span class="token punctuation">.</span>plot_posterior<span class="token punctuation">(</span>censored_idata<span class="token punctuation">,</span> var_names<span class="token operator">=</span><span class="token punctuation">[</span><span class="token string">"slope"</span><span class="token punctuation">]</span><span class="token punctuation">,</span> ref_val<span class="token operator">=</span>slope<span class="token punctuation">,</span> ax<span class="token operator">=</span>ax<span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
ax<span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">]</span><span class="token punctuation">.</span><span class="token builtin">set</span><span class="token punctuation">(</span>title<span class="token operator">=</span><span class="token string">"Censored regression\n(censored data)"</span><span class="token punctuation">,</span> xlabel<span class="token operator">=</span><span class="token string">"slope"</span><span class="token punctuation">)</span><span class="token punctuation">;</span>
az<span class="token punctuation">.</span>plot_posterior<span class="token punctuation">(</span>imputed_censored_idata<span class="token punctuation">,</span> var_names<span class="token operator">=</span><span class="token punctuation">[</span><span class="token string">"slope"</span><span class="token punctuation">]</span><span class="token punctuation">,</span> ref_val<span class="token operator">=</span>slope<span class="token punctuation">,</span> ax<span class="token operator">=</span>ax<span class="token punctuation">[</span><span class="token number">3</span><span class="token punctuation">]</span><span class="token punctuation">)</span>