I think David_sd has a great answer. I'll add how you might adapt his approach to get a smoother gradient, and how you could use plotly instead.
import numpy as np
from matplotlib import pyplot as plt
# Prepare data
n = 100
cmap = plt.get_cmap("bwr")
theta = np.linspace(-4 * np.pi, 4 * np.pi, n)
z = np.linspace(-2, 2, n)
r = z**2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)
T = (2 * np.random.rand(n) - 1) # Values in [-1, 1]
If you don't need the colormap, you can get a 3d curve very easily. It's a shame that you can't pass a colormap
argument into the plotting function as you can with the scatterplot/surface plot options you mentioned.
ax = plt.figure().add_subplot(projection='3d')
ax.plot(x, y, z)
To apply a colormap, I'd use the same approach as David_sd (with the limitations you identified).
# Build segments for Line3DCollection
points = np.array([x, y, z]).T.reshape(-1, 1, 3)
segments = np.concatenate([points[:-1], points[1:]], axis=1)
where points is a (100, 1, 3) representation of the points in space:
array([
[[ x1, y1, z1 ]],
[[ x2, y2, z2 ]],
[[ x3, y3, z3 ]],
...
])
and segments is a (99, 2, 3) representation of the point-to-point connections:
array([
[[ x1, y1, z1 ],
[ x2, y2, z2 ]],
[[ x2, y2, z2 ],
[ x3, y3, z3 ]],
...
])
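To make those shapes concrete, here is a small sketch with three made-up points (the values are purely illustrative and not part of the original data). It applies the same two lines and prints the resulting shapes:

```python
import numpy as np

# Three made-up points, chosen so each coordinate is easy to recognise
x = np.array([0.0, 1.0, 2.0])
y = np.array([0.0, 10.0, 20.0])
z = np.array([0.0, 100.0, 200.0])

# Same construction as above: one row per point, then pair up neighbours
points = np.array([x, y, z]).T.reshape(-1, 1, 3)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

print(points.shape)    # (3, 1, 3)
print(segments.shape)  # (2, 2, 3)
print(segments[0])     # start and end point of the first segment
```

Each segment ends where the next one begins, which is why n points always give n - 1 segments.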
Then build the collection as follows, using T[:-1] so that there is exactly one color value per segment.
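from mpl_toolkits.mplot3d.art3d import Line3DCollection
from matplotlib.colors import Normalize
norm = Normalize(vmin=T.min(), vmax=T.max())
colors = cmap(norm(T[:-1])) # Use T[:-1] to match number of segments
lc = Line3DCollection(segments, colors=colors, linewidth=2)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.add_collection3d(lc)
# Set the limits explicitly: depending on the Matplotlib version,
# add_collection3d may not autoscale the axes to the collection
ax.set_xlim(x.min(), x.max())
ax.set_ylim(y.min(), y.max())
ax.set_zlim(z.min(), z.max())
plt.show()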
This of course doesn't have the smooth gradient you want. One way to approximate that is to increase n, e.g. to n=1000, so that each constant-color segment becomes too short to notice.
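If you can't simply regenerate the data with a larger n (say x, y, z and T come from measurements), one possible workaround is to resample what you have. The sketch below is my own assumption rather than part of David_sd's approach: `upsample_curve` and `factor` are hypothetical names, and it uses plain linear interpolation via np.interp. The curve keeps its original shape, but the values blend gradually between the original samples.

```python
import numpy as np

def upsample_curve(x, y, z, T, factor=10):
    """Linearly resample a 3D curve and its values onto a finer grid.

    Hypothetical helper for illustration: returns the line segments and
    one value per segment, ready to be turned into colors.
    """
    n = len(x)
    s = np.arange(n)                                   # original sample index
    s_fine = np.linspace(0, n - 1, (n - 1) * factor + 1)
    xf, yf, zf, Tf = (np.interp(s_fine, s, a) for a in (x, y, z, T))
    points = np.array([xf, yf, zf]).T.reshape(-1, 1, 3)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    return segments, Tf[:-1]                           # one value per segment

# Same data as at the top of this answer
n = 100
theta = np.linspace(-4 * np.pi, 4 * np.pi, n)
z = np.linspace(-2, 2, n)
r = z**2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)
T = 2 * np.random.rand(n) - 1

segments, T_seg = upsample_curve(x, y, z, T, factor=10)
print(segments.shape)  # (990, 2, 3)
```

You would then pass segments to Line3DCollection and use cmap(norm(T_seg)) for the colors, exactly as in the code above.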
If that's not satisfying, I might switch to plotly, which gets a pretty good gradient even with n=100:
import plotly.graph_objects as go
# Create 3D line plot with color
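fig = go.Figure(data=go.Scatter3d(
    x=x,
    y=y,
    z=z,
    mode='lines',
    line=dict(
        color=T,
        # plotly's 'RdBu' runs red (low) to blue (high), the reverse of
        # matplotlib's "bwr"; use 'RdBu_r' if the direction should match
        colorscale='RdBu',
        cmin=-1,
        cmax=1,
        width=6
    )
))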
fig.show()