Chapter 7. Ulysses’ Compass
# Simulate the plant-growth / fungus data, fit models m6_6, m6_7 and m6_8 with a
# Laplace approximation (via SVI), then compute WAIC for m6_7.
with numpyro.handlers.seed(rng_seed=71):
    # number of plants
    N = 100

    # simulate initial heights
    h0 = numpyro.sample("h0", dist.Normal(10, 2).expand([N]))

    # assign treatments (first N // 2 plants get 0, the remaining N // 2 get 1)
    # and simulate fungus and growth. Fungus is a 0/1 draw whose probability is
    # 0.5 for untreated plants and 0.1 for treated ones, so it depends on the
    # treatment; growth then depends on fungus (mean 5 without, 2 with).
    treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)
    fungus = numpyro.sample(
        "fungus", dist.Binomial(total_count=1, probs=(0.5 - treatment * 0.4))
    )
    h1 = h0 + numpyro.sample("diff", dist.Normal(5 - 3 * fungus))

    # compose a clean data frame
    d = pd.DataFrame({"h0": h0, "h1": h1, "treatment": treatment, "fungus": fungus})


def model(h0, h1):
    """m6_6: final height is initial height scaled by a positive factor ``p``.

    The LogNormal prior keeps ``p`` strictly positive.
    """
    p = numpyro.sample("p", dist.LogNormal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = h0 * p
    numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)


m6_6 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_6, optim.Adam(0.1), Trace_ELBO(), h0=d.h0.values, h1=d.h1.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_6 = svi_result.params


def model(treatment, fungus, h0, h1):
    """m6_7: growth factor is linear in both treatment and fungus.

    Fungus is itself simulated from the treatment above, so this model
    conditions on a variable downstream of the treatment.
    """
    a = numpyro.sample("a", dist.LogNormal(0, 0.2))
    bt = numpyro.sample("bt", dist.Normal(0, 0.5))
    bf = numpyro.sample("bf", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    p = a + bt * treatment + bf * fungus
    mu = h0 * p
    numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)


m6_7 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_7,
    optim.Adam(0.3),
    Trace_ELBO(),
    treatment=d.treatment.values,
    fungus=d.fungus.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_7 = svi_result.params


def model(treatment, h0, h1):
    """m6_8: growth factor is linear in treatment only (fungus omitted)."""
    a = numpyro.sample("a", dist.LogNormal(0, 0.2))
    bt = numpyro.sample("bt", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    p = a + bt * treatment
    mu = h0 * p
    numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)


m6_8 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_8,
    optim.Adam(1),
    Trace_ELBO(),
    treatment=d.treatment.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_8 = svi_result.params

# Draw 1000 posterior samples from the fitted m6_7 guide and evaluate the
# pointwise log-likelihood of the observed h1 under each draw.
post = m6_7.sample_posterior(random.PRNGKey(11), p6_7, sample_shape=(1000,))
logprob = log_likelihood(
    m6_7.model,
    post,
    treatment=d.treatment.values,
    fungus=d.fungus.values,
    h0=d.h0.values,
    h1=d.h1.values,
)
# `[None, ...]` prepends a chain axis so the draws read as a single chain.
# The values are stored in the dedicated `log_likelihood` group; keeping them
# in `sample_stats` is deprecated in ArviZ.
az6_7 = az.from_dict(log_likelihood={"h1": logprob["h1"][None, ...]})
# NOTE(review): as a bare expression this result is discarded when run as a
# plain script; presumably it came from a notebook cell, where it is displayed.
az.waic(az6_7, scale="deviance")