在复现 GitHub 上的 SMT_Tutorial 代码时,发现生成的 3D 图片中点数出现畸形。请问如何解决这个问题?

在复现 GitHub 上的 SMT_Tutorial 项目时,2D 图片能够正常生成,但 3D 图片在复现时出现了点数畸形,并且呈现竖条状分布。请问这可能是什么原因,如何解决这个问题?
from future import print_function, division
import numpy as np
from scipy import linalg
from smt.utils.misc import compute_rms_error

from smt.problems import Sphere, NdimRobotArm, Rosenbrock
from smt.sampling_methods import LHS
from smt.surrogate_models import LS, QP, KPLS, KRG, KPLSK, GEKPLS, MGP

#to ignore warning messages
import warnings
warnings.filterwarnings(“ignore”)

try:
from smt.surrogate_models import IDW, RBF, RMTC, RMTB
compiled_available = True
except:
compiled_available = False

try:
import matplotlib.pyplot as plt
plot_status = True
except:
plot_status = False

import scipy.interpolate

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm

########### Initialization of the problem, construction of the training and validation points

ndim = 2
ndoe = 20 #int(10*ndim)

Define the function

fun = Rosenbrock(ndim=ndim)

Construction of the DOE

in order to have the always same LHS points, random_state=1

sampling = LHS(xlimits=fun.xlimits, criterion=’ese’, random_state=1)
xt = sampling(ndoe)

Compute the outputs

yt = fun(xt)

Construction of the validation points

ntest = 200 #500
sampling = LHS(xlimits=fun.xlimits, criterion=’ese’, random_state=1)
xtest = sampling(ntest)
ytest = fun(xtest)

#To visualize the DOE points
fig = plt.figure(figsize=(10, 10))
plt.scatter(xt[:,0],xt[:,1],marker = ‘x’,c=’b’,s=200,label=’Training points’)
plt.scatter(xtest[:,0],xtest[:,1],marker = ‘.’,c=’k’, s=200, label=’Validation points’)
plt.title(‘DOE’)
plt.xlabel(‘x1’)
plt.ylabel(‘x2’)
plt.legend()
plt.show()

To plot the Rosenbrock function

x = np.linspace(-2,2,50)
res = []
for x0 in x:
for x1 in x:
res.append(fun(np.array([[x0,x1]])))
res = np.array(res)
res = res.reshape((50,50)).T
X,Y = np.meshgrid(x,x)
fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(projection=’3d’)
surf = ax.plot_surface(X, Y, res, cmap=cm.viridis,
linewidth=0, antialiased=False,alpha=0.5)

ax.scatter(xt[:,0],xt[:,1],yt,zdir=’z’,marker = ‘x’,c=’b’,s=200,label=’Training point’)
ax.scatter(xtest[:,0],xtest[:,1],ytest,zdir=’z’,marker = ‘.’,c=’k’,s=200,label=’Validation point’)

plt.title(‘Rosenbrock function’)
plt.xlabel(‘x1’)
plt.ylabel(‘x2’)
plt.legend()
plt.show()

图1 复现代码畸形图片

图2 项目作者生成图片

项目链接:github.com/SMTorg/smt/blob/master/...

讨论数量: 3
Jason990420

IMO, it is caused by different shapes of x, y, z.

>>> # Shape for the x and y
>>> tuple(map(lambda x:x.shape, (xt[:, 0], xt[:, 1], xtest[:, 0], xtest[:, 1])))
((20,), (20,), (200,), (200,))
>>> # Shape for the z
>>> tuple(map(lambda x:x.shape, (yt, ytest)))
((20, 1), (200, 1))
>>> # Revise the shape of z
>>> tuple(map(lambda x:x.shape, (yt[:, 0], ytest[:, 0])))
((20,), (200,))

So revise statements from

ax.scatter(xt[:,0],xt[:,1],yt,zdir='z',marker = 'x',c='b',s=200,label='Training point')
ax.scatter(xtest[:,0],xtest[:,1],ytest,zdir='z',marker = '.',c='k',s=200,label='Validation point')

to

ax.scatter(
    xt[:, 0], xt[:, 1],yt[:, 0], zdir='z', marker='x', c='b', s=200, label='Training point')
ax.scatter(
    xtest[:, 0], xtest[:, 1], ytest[:, 0], zdir='z', marker='.', c='k', s=200, label='Validation point')
1个月前 评论
dustxan (楼主) 1个月前

在三维图形中,需要用到 x、y 和 z 三个坐标轴来表示数据点的位置。每一个数据点都需要有这三个坐标的值。在绘制三维散点图时,我们需要确保 x、y 和 z 坐标的数据长度一致。 问题的根源 遇到的问题是因为 z 坐标的数据形状与 x 和 y 坐标的数据形状不一致。具体来说,yt 和 ytest 的数据形状是 (20, 1) 和 (200, 1),而 xt[:, 0] 和 xt[:, 1] 的数据形状是 (20,) 和 (200,)。这导致了绘图函数无法正确处理这些数据,无法绘制出想要的图形。 需要将 yt 和 ytest 转换成一维数组。具体来说,我们需要从 yt 和 ytest 中提取出它们的第一个元素,将它们转换成一维数组。这可以通过 [:, 0] 实现。 这样,就能确保 x、y 和 z 坐标的数据长度一致,绘图函数就能正常工作了。

1个月前 评论

讨论应以学习和精进为目的。请勿发布不友善或者负能量的内容,与人为善,比聪明更重要!