This blog post allowed me to figure out the answer, so most of the credit goes there, and I would recommend reading it in full.
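The key point is that you need to include the shape argument in your definition of y_obs:
y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, shape=x_shared.shape, observed=y_train)
Note the shape=x_shared.shape! As I understand it, without it the size of y_obs is taken from observed=y_train and stays fixed at the length of the training data, so it does not follow x_shared when you swap in the test inputs.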
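I also saw in multiple examples (the blog and the PyMC docs) that you should call
post_pred_test = pm.sample_posterior_predictive(trace, predictions=True)
i.e. set predictions=True. As far as I can tell this does not change the sampled values; it only stores them in the predictions group of the returned InferenceData instead of posterior_predictive, which is the intended place for out-of-sample predictions.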
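Your code has some other issues, at least with my version of PyMC. You cannot do post_pred_train["y_obs"].mean(axis=0), because sample_posterior_predictive now returns an ArviZ InferenceData object rather than a dict of arrays, but I got it working with: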
plt.plot(
    x_train,
    post_pred_train.posterior_predictive["y_obs"].mean(("chain", "draw")),
    label="Posterior predictive (train)",
    color="red",
)
and similarly (but, confusingly, slightly differently) for the test data:
plt.plot(
    x_test,
    post_pred_test.predictions["y_obs"].mean(("chain", "draw")),
    label="Posterior predictive (test)",
    color="orange",
)
Note the .predictions here instead of .posterior_predictive: that is the group predictions=True writes to. In both cases, I needed to take the mean across both the chain and draw dimensions.
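To see why both dimensions are needed: the y_obs DataArray has dims (chain, draw, observation), whereas PyMC3 returned arrays with all chains and draws stacked on the first axis, which is why mean(axis=0) used to be enough. A NumPy sketch with random placeholder numbers (not real posterior samples):

```python
import numpy as np

# Placeholder draws with the same layout as
# post_pred_train.posterior_predictive["y_obs"]: (chain, draw, observation).
n_chains, n_draws, n_obs = 4, 1000, 50
rng = np.random.default_rng(0)
samples = rng.normal(loc=3.0, scale=1.0, size=(n_chains, n_draws, n_obs))

# Averaging over the first axis only collapses the chains,
# leaving one row per draw: shape (n_draws, n_obs).
per_draw = samples.mean(axis=0)

# Averaging over chain AND draw leaves one value per observation: shape (n_obs,).
# This is the NumPy analogue of .mean(("chain", "draw")) on the DataArray.
per_obs = samples.mean(axis=(0, 1))

# PyMC3-style layout: chains and draws stacked on a single first axis,
# where .mean(axis=0) gives the same per-observation result.
flat = samples.reshape(n_chains * n_draws, n_obs)
assert np.allclose(flat.mean(axis=0), per_obs)

print(per_draw.shape, per_obs.shape)  # (1000, 50) (50,)
```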