I hope this is what is required.
import matplotlib.pyplot as plt
import numpy as np
# Three depth profiles z that reach different maximum values; each panel
# plots exp(-z) against z.
z1 = np.linspace(0, 1.4, 20)
z2 = np.linspace(0, 1.1, 20)
z3 = np.linspace(0, 1.0, 20)
profiles = [z1, z2, z3]

fig, axes = plt.subplots(1, len(profiles))
for ax, z in zip(axes, profiles):
    ax.plot(np.exp(-z), z, c='k')

# Shrink each panel so that one y-data unit covers the same physical height
# in every panel.  The ratio uses the autoscaled y-limits (the range that is
# actually drawn, including matplotlib's default margins) instead of the raw
# data maximum, so it stays correct even when the data do not start at 0.
orig_pos = [ax.get_position() for ax in axes]
y_spans = [abs(top - bottom) for bottom, top in (ax.get_ylim() for ax in axes)]
max_span = max(y_spans)
height_ratios = [span / max_span for span in y_spans]

# Keep every panel's bottom edge where subplots() put it and scale only the
# height, so all panels share a common baseline.
for ax, pos, ratio in zip(axes, orig_pos, height_ratios):
    ax.set_position([pos.x0, pos.y0, pos.width, pos.height * ratio])

plt.show()