Now I got it :) Thanks @jared, `pcolormesh` was the right function, but I had to map the values to colors explicitly and pass those colors as the plotted variable:
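```python
import numpy as np
from matplotlib import pyplot as plt

# 100 x 100 grid over [-2, 2] x [-2, 2]
axes = (np.linspace(-2, 2, 100), np.linspace(-2, 2, 100))
xx, yy = np.meshgrid(*axes, indexing="xy")

z = np.abs(xx * yy).astype(int)  # integer categories 0, 1, 2, 3, 4
z[z == 0] = 4  # cells in category 0 get the same color as category 4

# Calling the colormap with an integer array uses `z` as an index into its color table
cmap = plt.get_cmap("Set1")
z_color = cmap(z)  # shape (100, 100, 4): one RGBA color per cell

fig, ax = plt.subplots()
# An (M, N, 4) RGBA array is drawn as explicit per-cell colors (needs a recent Matplotlib)
ax.pcolormesh(xx, yy, z_color)
plt.show()
```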
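The trick is the `cmap(z)` call. When a Matplotlib colormap is called with an integer array, the integers are used directly as indices into its color table; floats are instead treated as normalized data in [0, 1]. Here is a small sketch to check that behaviour (the arrays and variable names are only for illustration):

```python
import numpy as np
from matplotlib import pyplot as plt

cmap = plt.get_cmap("Set1")  # a ListedColormap with 9 discrete colors

# Integer input: each value indexes the colormap's table directly
z = np.array([[1, 2], [3, 4]])
rgba = cmap(z)
print(rgba.shape)  # (2, 2, 4): one RGBA color per cell

# Same result as looking up the listed colors by index, with alpha = 1
listed = np.asarray(cmap.colors)[z]
print(np.allclose(rgba[..., :3], listed))  # True
print(np.allclose(rgba[..., 3], 1.0))  # True

# Float input is scaled over [0, 1], so 1.0 maps to the last color,
# not to the color at index 1
print(np.allclose(cmap(1.0), cmap(cmap.N - 1)))  # True
print(np.allclose(cmap(1), cmap(1.0)))  # False
```

So if `z` had been left as floats, the colors would have been picked by position along the colormap rather than by category, which is why the `.astype(int)` matters here.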