pyplotで3Dグラフを描くためのサンプル・コード
from mpl_toolkits.mplot3d import Axes3D import matplotlib.pyplot as plt import numpy as np def func1(x): return x[0]**2 + x[1]**2 x0 = np.arange(-3.0, 3.0, 0.1) x1 = np.arange(-3.0, 3.0, 0.1) X0, X1 = np.meshgrid(x0, x1) y = func1(np.array([X0, X1])) fig = plt.figure() ax = Axes3D(fig) ax.set_xlabel("x0") ax.set_ylabel("x1") ax.set_zlabel("y(x0, x1)") ax.plot_wireframe(X0, X1, y, rstride=2, cstride=2) plt.savefig("mplot3d_wireframe.png") ax.clear() ax.scatter(X0, X1, y, s=1) plt.savefig("mplot3d_scatter.png") ax.clear() ax.plot_surface(X0, X1, y, rstride=1, cstride=1, cmap="plasma") plt.savefig("mplot3d_surface.png")
実行結果