我想知道：找出两组等高线之间所有交点（在舍入误差范围内）的最佳方法是什么？下面是一个例子：
import matplotlib.pyplot as plt
import numpy as np

# Sample two scalar fields on a 500 x 500 grid covering [-1, 1] x [-1, 1].
x = np.linspace(-1, 1, 500)
X, Y = np.meshgrid(x, x)
Z1 = np.abs(np.sin(2 * X**2 + Y))
Z2 = np.abs(np.cos(2 * Y**2 + X**2))

# contour() is called without X/Y, so both contour sets are drawn in
# grid-index coordinates rather than in the [-1, 1] data range.
plt.contour(Z1, colors='k')
plt.contour(Z2, colors='r')
plt.show()
我希望得到类似下面这样的功能：
# Hypothetical interface being asked for: intersect() is not an existing function.
intersection_points = intersect(contour1, contour2)
print(intersection_points)
# Desired output: [(x1, y1), ..., (xn, yn)]
解决方法
import collections

import numpy as np
import scipy.cluster.hierarchy as hier
import scipy.spatial as spatial
import scipy.spatial.distance as dist


def intersection(points1, points2, eps):
    """Return the points of *points1* that lie within *eps* of some point of *points2*.

    Each point of *points2* is matched to its nearest neighbour in *points1*;
    only matches closer than *eps* are kept.  The result approximates where the
    two point sets cross.  It may contain repeats, because several points of
    *points2* can share the same nearest neighbour.  That is why the caller
    clusters the result afterwards.

    Args:
        points1: (n, d) array-like of candidate points.
        points2: (m, d) array-like of reference points.
        eps: distance threshold, passed to ``KDTree.query`` as
            ``distance_upper_bound``.

    Returns:
        (k, d) float array of points taken from *points1*; k may be 0.
    """
    points1 = np.asarray(points1, dtype=float)
    if len(points1) == 0 or len(points2) == 0:
        # Nothing to match: return an empty array with the same column count.
        return points1[:0].copy()
    tree = spatial.KDTree(points1)
    distances, indices = tree.query(points2, k=1, distance_upper_bound=eps)
    # Queries with no neighbour within eps come back with distance == inf
    # (and an out-of-range index), so keep only the finite ones.
    return tree.data[indices[np.isfinite(distances)]]


def cluster(points, cluster_size):
    """Merge nearby points and return one centroid per flat cluster.

    Uses average-linkage hierarchical clustering on *squared* Euclidean
    distances.  Clusters are cut so that each flat cluster has a cophenetic
    distance of at most *cluster_size*, measured in squared units.  Centroids
    are ordered from the largest cluster to the smallest.

    Args:
        points: (n, d) array-like of points.
        cluster_size: threshold passed to ``fcluster`` with
            ``criterion='distance'``.

    Returns:
        (c, d) float array of cluster means.  If fewer than two points are
        given, they are returned unchanged as a float array, because
        ``linkage`` needs at least two observations.
    """
    points = np.asarray(points, dtype=float)
    if len(points) < 2:
        return points.copy()
    dists = dist.pdist(points, metric='sqeuclidean')
    linkage_matrix = hier.linkage(dists, 'average')
    groups = hier.fcluster(linkage_matrix, cluster_size, criterion='distance')
    return np.array([points[members].mean(axis=0) for members in clusterlists(groups)])


def contour_points(contour, steps=1):
    """Stack the vertices of every path in a Matplotlib ContourSet into one array.

    Args:
        contour: the ContourSet returned by ``plt.contour``.
        steps: each path segment is subdivided into this many pieces
            (``Path.interpolated``) before its vertices are collected.

    Returns:
        (n, 2) float array of vertices; an empty (0, 2) array if there are no paths.
    """
    if hasattr(contour, 'get_paths'):
        # Matplotlib >= 3.8: a ContourSet is itself a Collection with one path
        # per level; its per-level ``collections`` attribute was deprecated in
        # 3.8 and removed in later releases.
        paths = contour.get_paths()
    else:
        # Older Matplotlib: one LineCollection per contour level.
        paths = [path for linecol in contour.collections for path in linecol.get_paths()]
    vertices = [path.interpolated(steps).vertices for path in paths]
    if not vertices:
        return np.empty((0, 2))
    # np.vstack replaces np.row_stack, a deprecated alias (NumPy 2.0).
    return np.vstack(vertices)


def clusterlists(T):
    """Group the indices of *T* by label value, largest group first.

    Adapted from https://stackoverflow.com/a/2913071/190597 (denis).

    >>> clusterlists([2, 1, 1, 1, 2, 2, 2, 2, 2, 1])
    [[0, 4, 5, 6, 7, 8], [1, 2, 3, 9]]
    """
    groups = collections.defaultdict(list)
    for i, label in enumerate(T):
        groups[label].append(i)
    return sorted(groups.values(), key=len, reverse=True)


def main():
    """Plot two contour sets and mark their approximate intersection points."""
    # Imported here so the numeric helpers above can be used without a plotting backend.
    import matplotlib.pyplot as plt

    # Every intersection point must be within eps of a point on the other
    # contour path.  Units are grid indices, because contour() below is called
    # without X/Y coordinates.
    eps = 1.0
    # Cluster together intersection points so that the original points in each
    # flat cluster have a cophenetic distance <= cluster_size (squared units,
    # since cluster() uses squared Euclidean distances).
    cluster_size = 100

    x = np.linspace(-1, 1, 500)
    X, Y = np.meshgrid(x, x)
    Z1 = np.abs(np.sin(2 * X**2 + Y))
    Z2 = np.abs(np.cos(2 * Y**2 + X**2))
    contour1 = plt.contour(Z1, colors='k')
    contour2 = plt.contour(Z2, colors='r')

    points1 = contour_points(contour1)
    points2 = contour_points(contour2)
    intersection_points = intersection(points1, points2, eps)
    intersection_points = cluster(intersection_points, cluster_size)

    plt.scatter(intersection_points[:, 0], intersection_points[:, 1], s=20)
    plt.show()


if __name__ == '__main__':
    main()
运行结果：