The Fool In The Valleyの雑記帳

-- 好奇心いっぱいのおじいちゃんが綴るよしなし事 --

pyplotで3Dグラフを描く

pyplotで3Dグラフを描くためのサンプル・コード

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np

def func1(x):
    return x[0]**2 + x[1]**2

x0 = np.arange(-3.0, 3.0, 0.1)
x1 = np.arange(-3.0, 3.0, 0.1)

X0, X1 = np.meshgrid(x0, x1)
y = func1(np.array([X0, X1]))

fig = plt.figure()
ax = Axes3D(fig)

ax.set_xlabel("x0")
ax.set_ylabel("x1")
ax.set_zlabel("y(x0, x1)")

ax.plot_wireframe(X0, X1, y, rstride=2, cstride=2)
plt.savefig("mplot3d_wireframe.png")

ax.clear()
ax.scatter(X0, X1, y, s=1)
plt.savefig("mplot3d_scatter.png")

ax.clear()
ax.plot_surface(X0, X1, y, rstride=1, cstride=1, cmap="plasma")
plt.savefig("mplot3d_surface.png")

実行結果

f:id:tfitv:20210830111850p:plain
mplot3d_wireframe.png
f:id:tfitv:20210830111928p:plain
mplot3d_scatter.png
f:id:tfitv:20210830111954p:plain
mplot3d_surface.png