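# 问题：在复现 GitHub 上的 SMT_Tutorial 代码时，生成的 3D 图中散点（训练点/验证点）显示畸形。请问如何解决？
# (Question: when reproducing the SMT_Tutorial code from GitHub, the scattered training/validation points in the generated 3D plot look distorted. How can this be fixed?)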

from future import print_function, division
import numpy as np
from scipy import linalg
from smt.utils.misc import compute_rms_error

from smt.problems import Sphere, NdimRobotArm, Rosenbrock
from smt.sampling_methods import LHS
from smt.surrogate_models import LS, QP, KPLS, KRG, KPLSK, GEKPLS, MGP

#to ignore warning messages
import warnings
warnings.filterwarnings(“ignore”)

try:
from smt.surrogate_models import IDW, RBF, RMTC, RMTB
compiled_available = True
except:
compiled_available = False

try:
import matplotlib.pyplot as plt
plot_status = True
except:
plot_status = False

import scipy.interpolate

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm

########### Initialization of the problem, construction of the training and validation points

ndim = 2
ndoe = 20  # number of training points (alternative: int(10 * ndim))

# Define the function: the Rosenbrock benchmark problem in `ndim` dimensions
fun = Rosenbrock(ndim=ndim)

# random_state=1 so that the same LHS points are drawn on every run
sampling = LHS(xlimits=fun.xlimits, criterion='ese', random_state=1)
xt = sampling(ndoe)

# Responses at the training points; fun returns a column of shape (ndoe, 1)
yt = fun(xt)

# Construction of the validation points
ntest = 200  # alternative: 500
# NOTE(review): this reuses the training design's random_state; a different
# seed would give a validation design drawn independently of the training one.
sampling = LHS(xlimits=fun.xlimits, criterion='ese', random_state=1)
xtest = sampling(ntest)
ytest = fun(xtest)

# To visualize the DOE points: training and validation designs in the (x1, x2) plane
fig = plt.figure(figsize=(10, 10))
plt.scatter(xt[:, 0], xt[:, 1], marker='x', c='b', s=200, label='Training points')
plt.scatter(xtest[:, 0], xtest[:, 1], marker='.', c='k', s=200, label='Validation points')
plt.title('DOE')
plt.xlabel('x1')
plt.ylabel('x2')
plt.legend()
plt.show()

# To plot the Rosenbrock function surface together with the DOE points

n_grid = 50  # grid resolution per axis
x = np.linspace(-2, 2, n_grid)
X, Y = np.meshgrid(x, x)
# Evaluate fun on every grid node in a single call: it takes an (n, 2) array
# and returns an (n, 1) column. Reshaping back onto the grid gives
# res[j, i] == f(X[j, i], Y[j, i]), which is the layout plot_surface expects.
res = fun(np.column_stack([X.ravel(), Y.ravel()])).reshape(X.shape)

fig = plt.figure(figsize=(15, 10))
# plot_surface and 3D scatter need a 3D axes, so create one explicitly.
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X, Y, res, cmap=cm.viridis,
                       linewidth=0, antialiased=False, alpha=0.5)

# fun returns z values as (n, 1) columns while the x and y slices are 1-D
# (n,) arrays; flatten z so all three coordinates share the same 1-D shape.
# Otherwise the scattered points are drawn distorted.
ax.scatter(xt[:, 0], xt[:, 1], np.ravel(yt), zdir='z',
           marker='x', c='b', s=200, label='Training point')
ax.scatter(xtest[:, 0], xtest[:, 1], np.ravel(ytest), zdir='z',
           marker='.', c='k', s=200, label='Validation point')

ax.set_title('Rosenbrock function')
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.legend()
plt.show()

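In my opinion, this is caused by x, y and z having different shapes: x and y are 1-D arrays of shape (n,), while the z values (`yt`, `ytest`) are 2-D column arrays of shape (n, 1).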

```python
>>> # Shapes of x and y
>>> tuple(map(lambda x:x.shape, (xt[:, 0], xt[:, 1], xtest[:, 0], xtest[:, 1])))
((20,), (20,), (200,), (200,))
>>> # Shapes of z
>>> tuple(map(lambda x:x.shape, (yt, ytest)))
((20, 1), (200, 1))
>>> # Shapes of z after selecting the single column
>>> tuple(map(lambda x:x.shape, (yt[:, 0], ytest[:, 0])))
((20,), (200,))
```

So change the statements from

```python
ax.scatter(xt[:,0],xt[:,1],yt,zdir='z',marker = 'x',c='b',s=200,label='Training point')
ax.scatter(xtest[:,0],xtest[:,1],ytest,zdir='z',marker = '.',c='k',s=200,label='Validation point')
```

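to the following. Indexing with `[:, 0]` turns each (n, 1) column into a 1-D array of shape (n,), matching x and y; `yt.ravel()` works equally well: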

```python
ax.scatter(
    xt[:, 0], xt[:, 1], yt[:, 0], zdir='z', marker='x', c='b', s=200, label='Training point')
ax.scatter(
    xtest[:, 0], xtest[:, 1], ytest[:, 0], zdir='z', marker='.', c='k', s=200, label='Validation point')
```
2周前 评论
dustxan （楼主） 2周前

2周前 评论