Why does the grid look uneven in a 3D projection after setting the x and y ticks?


With the code below the grid is not drawn evenly: as the figure shows, the grid squares on the xy plane are much larger than those along the z axis. I have noticed that the squares are equal when I don't set the x and y ticks. How can I get a grid of equal boxes?

import math
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on Matplotlib < 3.2

fig = plt.figure()
# Create a single 3D axes (instead of a 2D subplot followed by plt.axes(projection='3d'))
ax = fig.add_subplot(111, projection='3d')
x = np.linspace(-math.pi, math.pi, 30)
y = np.linspace(-math.pi, math.pi, 30)
xx, yy = np.meshgrid(x, y)
print("X_grid shape: ", xx.shape)
zz = -2*(np.cos(xx) + np.cos(yy))

ax.plot_surface(xx, yy, zz)

# Three ticks per axis: -pi, 0 and pi
plt.xticks([x[0], 0, x[-1]], [r'$-\pi$', r'$0$', r'$\pi$'])
plt.yticks([y[0], 0, y[-1]], [r'$-\pi$', r'$0$', r'$\pi$'])
plt.xlabel("kx")
plt.ylabel("ky")
plt.show()

[Figure: resulting surface plot, with larger grid squares on the xy plane than along z]
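One way to see the mismatch numerically is to compare the tick positions Matplotlib uses on each axis, since the 3D gridlines are drawn at the major ticks. The sketch below rebuilds the question's surface and returns the tick step along x, y and z; tick_spacings is a hypothetical helper name, and the ticks are set with ax.set_xticks/ax.set_yticks rather than plt.xticks/plt.yticks (same effect on the current axes).

```python
import math
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on Matplotlib < 3.2)


def tick_spacings():
    """Rebuild the question's plot and return the tick steps on each axis."""
    x = np.linspace(-math.pi, math.pi, 30)
    y = np.linspace(-math.pi, math.pi, 30)
    xx, yy = np.meshgrid(x, y)
    zz = -2 * (np.cos(xx) + np.cos(yy))

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(xx, yy, zz)
    # Only x and y get hand-picked ticks, as in the question
    ax.set_xticks([-math.pi, 0, math.pi])
    ax.set_yticks([-math.pi, 0, math.pi])

    # Gridlines sit at the major ticks, so these steps are the grid-cell sizes
    steps = {
        "x": np.diff(ax.get_xticks()),
        "y": np.diff(ax.get_yticks()),
        "z": np.diff(ax.get_zticks()),
    }
    plt.close(fig)
    return steps


for name, step in tick_spacings().items():
    print(name, step)
```

The x and y steps come out as pi, while z keeps the step chosen by the default tick locator, which is why the cells on the z faces look smaller.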

Answer by Rex5:

Referring to the documentation I found this:

Note: Prior to version 1.0.0, the method of creating a 3D axes was different. For those using older versions of matplotlib, change ax = fig.add_subplot(111, projection='3d') to ax = Axes3D(fig).
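A minimal sketch of the post-1.0.0 idiom the note describes. make_3d_axes is a hypothetical helper name; the mplot3d import is only needed on Matplotlib older than 3.2, where the '3d' projection is not registered by default.

```python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (needed only on Matplotlib < 3.2)


def make_3d_axes():
    """Create a figure holding one 3D axes via add_subplot(projection='3d')."""
    fig = plt.figure()
    # The axes is created and attached to the figure in one call
    ax = fig.add_subplot(111, projection='3d')
    return fig, ax


fig, ax = make_3d_axes()
print(ax.name)
```

This form is also the safer choice going forward: Axes3D(fig) was deprecated as a way of adding axes in Matplotlib 3.4, and recent releases no longer attach such an instance to the figure automatically.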

See this for .set_zticks. Since this is a 3D plot, the method is documented under mplot3d rather than in the regular Axes API.
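A short sketch of setting z ticks on a 3D axes. pyplot provides xticks and yticks helpers but no zticks counterpart, so the z ticks go through the Axes3D methods. The z limits of -4 to 4 are an assumption chosen to match the range of the question's surface.

```python
import math
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (needed only on Matplotlib < 3.2)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_zlim(-4, 4)  # assumed range, roughly that of -2*(cos x + cos y)

# Tick positions and their labels are set separately
ax.set_zticks([-math.pi, 0, math.pi])
ax.set_zticklabels([r'$-\pi$', r'$0$', r'$\pi$'])

print(ax.get_zticks())
```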

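Gridlines are drawn at the tick positions, so with only three ticks (-π, 0, π) on x and y, the grid cells on those faces are π wide, while the z axis keeps its default, more closely spaced ticks. Giving the z axis the same ticks makes the spacing match, so I modified the code accordingly: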

import math
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on Matplotlib < 3.2

fig = plt.figure()
# Create the 3D axes directly; recent Matplotlib no longer attaches Axes3D(fig) to the figure
ax = fig.add_subplot(111, projection='3d')

x = np.linspace(-math.pi, math.pi, 30)
y = np.linspace(-math.pi, math.pi, 30)

X_grid, Y_grid = np.meshgrid(x, y)
print("X_grid shape: ", X_grid.shape)
Z = -2*(np.cos(X_grid) + np.cos(Y_grid))

ax.plot_surface(X_grid, Y_grid, Z)

# Same tick positions on every axis, so all gridlines share one spacing
ticks = [-math.pi, 0, math.pi]
labels = [r'$-\pi$', r'$0$', r'$\pi$']

#updated lines
ax.set_zticks(ticks)
ax.set_zticklabels(labels)

ax.set_xticks(ticks)
ax.set_xticklabels(labels)
ax.set_yticks(ticks)
ax.set_yticklabels(labels)

ax.set_xlabel("kx")
ax.set_ylabel("ky")
plt.show()
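A variation that avoids hard-coding tick lists: put one tick every π on all three axes with a MultipleLocator and format the labels as multiples of π. To make a π step span the same length in the 3D box along x, y and z, the box aspect is also scaled to the data ranges. This is a sketch under assumptions: pi_label and plot_with_square_cells are hypothetical names, Axes3D.set_box_aspect requires Matplotlib 3.3 or later, and perspective projection can still make cells look slightly different on screen.

```python
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, MultipleLocator
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (needed only on Matplotlib < 3.2)


def pi_label(val, pos):
    """Format a tick value that is a multiple of pi as a LaTeX label."""
    n = int(round(val / math.pi))
    if n == 0:
        return r'$0$'
    if n == 1:
        return r'$\pi$'
    if n == -1:
        return r'$-\pi$'
    return rf'${n}\pi$'


def plot_with_square_cells():
    x = np.linspace(-math.pi, math.pi, 30)
    y = np.linspace(-math.pi, math.pi, 30)
    X_grid, Y_grid = np.meshgrid(x, y)
    Z = -2 * (np.cos(X_grid) + np.cos(Y_grid))

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X_grid, Y_grid, Z)

    # One tick, and therefore one gridline, every pi on all three axes
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_major_locator(MultipleLocator(math.pi))
        axis.set_major_formatter(FuncFormatter(pi_label))

    # Scale the box to the data spans so a pi step has the same box length per axis
    spans = [hi - lo for lo, hi in (ax.get_xlim(), ax.get_ylim(), ax.get_zlim())]
    ax.set_box_aspect(spans)

    ax.set_xlabel("kx")
    ax.set_ylabel("ky")
    return fig, ax
```

Usage: call fig, ax = plot_with_square_cells() and then plt.show(). A locator-based setup keeps the grid consistent if the data range changes, whereas a fixed tick list has to be edited by hand.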