Coursera Algorithms Programming Assignment 5: Kd-Trees (98分)

题目地址:http://coursera.cs.princeton.edu/algs4/assignments/kdtree.html

分析:

Brute-force implementation. 蛮力实现的方法比较简单,就是逐个遍历每个point进行比较,实现下述API就可以了,没有什么难度。

 1 import java.util.ArrayList;
 2 import java.util.TreeSet;
 3 import edu.princeton.cs.algs4.Point2D;
 4 import edu.princeton.cs.algs4.RectHV;
 5 import edu.princeton.cs.algs4.StdDraw;
 6 /**
 7  * @author evasean www.cnblogs.com/evasean/
 8  */
 9 public class PointSET {
10     private TreeSet<Point2D> points;
11     public PointSET() {
12         // construct an empty set of points
13         points = new TreeSet<Point2D>();
14     }
15 
16     public boolean isEmpty() {
17         // is the set empty?
18         return points.isEmpty();
19     }
20 
21     public int size() {
22         // number of points in the set
23         return points.size();
24     }
25 
26     public void insert(Point2D p) {
27         // add the point to the set (if it is not already in the set)
28         if(p==null)
29             throw new IllegalArgumentException("Point2D p is not illegal!");
30         if(!points.contains(p))
31             points.add(p);
32     }
33 
34     public boolean contains(Point2D p) {
35         // does the set contain point p?
36         if(p==null)
37             throw new IllegalArgumentException("Point2D p is not illegal!");
38         return points.contains(p);
39     }
40 
41     public void draw() {
42         // draw all points to standard draw
43         for (Point2D p : points) {
44             p.draw();
45         }
46         StdDraw.show();
47     }
48 
49     public Iterable<Point2D> range(RectHV rect) {
50         // all points that are inside the rectangle (or on the boundary)
51         if(rect==null)
52             throw new IllegalArgumentException("RectHV rect is not illegal!");
53         ArrayList<Point2D> list = new ArrayList<Point2D>();
54         for(Point2D point : points){
55             if(rect.contains(point)) list.add(point);
56         }
57         return list;
58     }
59 
60     public Point2D nearest(Point2D p) {
61         // a nearest neighbor in the set to point p; null if the set is empty
62         if(p==null)
63             throw new IllegalArgumentException("Point2D p is not illegal!");
64         if(points.size() == 0) return null;
65         double neareatDistance = Double.POSITIVE_INFINITY;
66         Point2D nearest = null;
67         for(Point2D point : points){
68             double tmp = p.distanceTo(point);
69             if(Double.compare(neareatDistance, tmp) == 1){
70                 neareatDistance = tmp;
71                 nearest = point;
72             }
73                 
74         }
75         return nearest;
76     }
77 
78     public static void main(String[] args) {
79         // unit testing of the methods (optional)
80     }
81 }

2d-tree implementation.

kd-tree插入时要交替以x坐标和y坐标作为判断依据。比如root节点处的比较依据为x坐标,那么当要查找或插入一个新节点point时,比较root节点的x坐标和point的x坐标:如果后者比前者小,那么下一次要比较的就是root->left,如果后者比前者大,下一次要比较的就是root->right;若两者相等,则改用另一个坐标(这里是y坐标)比较来决定方向。进入下一层级之后,就要使用y坐标作为比较依据,再下一层又换回x坐标,依此类推。示例如下图(原文配图未能保留):

 

区域搜索:查找落在给定矩形区域范围内(含边界)的所有points。从root开始递归查找,如果给定的矩形不与当前节点的相关矩形相交,那么就没有必要继续查找该节点及其子树了。

最近节点搜索:查找与给定point距离最近的节点。从root开始递归查找其左右子树,如果给定point与已经查找到的最近节点的距离不大于该point到当前遍历节点的相关矩形的距离,那么就没必要遍历这个当前节点及其子树了。

  1 import java.util.ArrayList;
  2 import edu.princeton.cs.algs4.Point2D;
  3 import edu.princeton.cs.algs4.RectHV;
  4 import edu.princeton.cs.algs4.StdDraw;
  5 /**
  6  * @author evasean www.cnblogs.com/evasean/
  7  */
  8 public class KdTree {
  9     private Node root;
 10     private class Node {
 11         private Point2D p;
 12         /*
 13          * 节点的value就是包含该节点的矩形空间 其左右子树的矩形空间,就是通过该节点进行水平切分或垂直切分的两个子空间
 14          */
 15         private RectHV rect;
 16         private Node left, right;
 17         private int size;
 18         private boolean xCoordinate; // 标识该节点是否是以x坐标垂直切分
 19 
 20         public Node(Point2D point, RectHV rect, int size, boolean xCoordinate) {
 21             this.p = point;
 22             this.rect = rect;
 23             this.size = size;
 24             this.xCoordinate = xCoordinate;
 25         }
 26     }
 27 
 28     public KdTree() {
 29         // construct an empty set of points
 30     }
 31 
 32     public boolean isEmpty() {
 33         // is the set empty?
 34         return size() == 0;
 35     }
 36 
 37     public int size() {
 38         // number of points in the set
 39         return size(root);
 40     }
 41 
 42     private int size(Node x) {
 43         if (x == null)
 44             return 0;
 45         else
 46             return x.size;
 47     }
 48 
 49     public void insert(Point2D p) {
 50         // add the point to the set (if it is not already in the set)
 51         if (p == null)
 52             throw new IllegalArgumentException("Point2D p is not illegal!");
 53         if (root == null)
 54             root = new Node(p, new RectHV(0.0, 0.0, 1.0, 1.0), 1, true);
 55         else
 56             insert(root, p);
 57         // System.out.println("size="+root.size);
 58     }
 59 
 60     private void insert(Node x, Point2D p) {
 61         if (x.xCoordinate == true) { // x的切分标志是x坐标
 62             int cmp = Double.compare(p.x(), x.p.x());
 63             if (cmp == -1) {
 64                 if (x.left != null)
 65                     insert(x.left, p);
 66                 else {
 67                     RectHV parent = x.rect;
 68                     // 将节点x的矩形空间进行垂直切分后的左侧部分
 69                     double newXmin = parent.xmin();
 70                     double newYmin = parent.ymin();
 71                     double newXmax = x.p.x();
 72                     double newYmax = parent.ymax();
 73                     x.left = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, false);
 74                 }
 75             } else if (cmp == 1) {
 76                 if (x.right != null)
 77                     insert(x.right, p);
 78                 else {
 79                     RectHV parent = x.rect;
 80                     // 将节点x的矩形空间进行垂直切分后的右侧部分
 81                     double newXmin = x.p.x();
 82                     double newYmin = parent.ymin();
 83                     double newXmax = parent.xmax();
 84                     double newYmax = parent.ymax();
 85                     x.right = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, false);
 86                 }
 87             } else { // x.key.x() 与 p.x() 相等
 88                 int cmp2 = Double.compare(p.y(), x.p.y());
 89                 if (cmp2 == -1) {
 90                     if (x.left != null)
 91                         insert(x.left, p);
 92                     else {
 93                         x.left = new Node(p, x.rect, 1, false);
 94                     }
 95                 } else if (cmp2 == 1) {
 96                     if (x.right != null)
 97                         insert(x.right, p);
 98                     else {
 99                         x.right = new Node(p, x.rect, 1, false);
100                     }
101                 }
102             }
103         } else { // x的切分标志是y坐标
104             int cmp = Double.compare(p.y(), x.p.y());
105             if (cmp == -1) {
106                 if (x.left != null)
107                     insert(x.left, p);
108                 else {
109                     RectHV parent = x.rect;
110                     // 将节点x的矩形空间进行垂直切分后的左侧部分
111                     double newXmin = parent.xmin();
112                     double newYmin = parent.ymin();
113                     double newXmax = parent.xmax();
114                     double newYmax = x.p.y();
115                     x.left = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, true);
116                 }
117             } else if (cmp == 1) {
118                 if (x.right != null)
119                     insert(x.right, p);
120                 else {
121                     RectHV parent = x.rect;
122                     // 将节点x的矩形空间进行垂直切分后的左侧部分
123                     double newXmin = parent.xmin();
124                     double newYmin = x.p.y();
125                     double newXmax = parent.xmax();
126                     double newYmax = parent.ymax();
127                     x.right = new Node(p, new RectHV(newXmin, newYmin, newXmax, newYmax), 1, true);
128                 }
129             } else { // x.key.y() 与 p.y()相等
130                 int cmp2 = Double.compare(p.x(), x.p.x());
131                 if (cmp2 == -1) {
132                     if (x.left != null)
133                         insert(x.left, p);
134                     else {
135                         x.left = new Node(p, x.rect, 1, true);
136                     }
137                 } else if (cmp2 == 1) {
138                     if (x.right != null)
139                         insert(x.right, p);
140                     else {
141                         x.right = new Node(p, x.rect, 1, true);
142                     }
143                 }
144             }
145         }
146         x.size = 1 + size(x.left) + size(x.right);
147     }
148 
149     public boolean contains(Point2D p) {
150         // does the set contain point p?
151         if (p == null)
152             throw new IllegalArgumentException("Point2D p is not illegal!");
153         return contains(root, p);
154     }
155 
156     private boolean contains(Node x, Point2D p) {
157         if(x == null ) return false;
158         if (x.p.equals(p))
159             return true;
160         else {
161             if(x.xCoordinate == true){
162                 int cmp = Double.compare(p.x(), x.p.x());
163                 if(cmp == -1) return contains(x.left,p);
164                 else if(cmp == 1 ) return contains(x.right,p);
165                 else{
166                     int cmp2 = Double.compare(p.y(), x.p.y());
167                     if(cmp2 == -1) return contains(x.left,p);
168                     else if(cmp2 == 1 ) return contains(x.right,p);
169                     else return true;
170                 }
171             }else{
172                 int cmp = Double.compare(p.y(), x.p.y());
173                 if(cmp == -1) return contains(x.left,p);
174                 else if(cmp == 1 ) return contains(x.right,p);
175                 else{
176                     int cmp2 = Double.compare(p.x(), x.p.x());
177                     if(cmp2 == -1) return contains(x.left,p);
178                     else if(cmp2 == 1 ) return contains(x.right,p);
179                     else return true;
180                 }    
181             }
182         }
183     }
184 
185     public void draw() {
186         // draw all points to standard draw
187         StdDraw.setXscale(0, 1);
188         StdDraw.setYscale(0, 1);
189         draw(root);
190     }
191 
192     private void draw(Node x) {
193         if (x == null)
194             return;
195         StdDraw.setPenColor(StdDraw.BLACK);
196         StdDraw.setPenRadius(0.01);
197         x.p.draw();
198         if (x.xCoordinate == true) {
199             StdDraw.setPenColor(StdDraw.RED);
200             StdDraw.setPenRadius();
201             Point2D start = new Point2D(x.p.x(), x.rect.ymin());
202             Point2D end = new Point2D(x.p.x(), x.rect.ymax());
203             start.drawTo(end);
204         } else {
205             StdDraw.setPenColor(StdDraw.BLUE);
206             StdDraw.setPenRadius();
207             Point2D start = new Point2D(x.rect.xmin(), x.p.y());
208             Point2D end = new Point2D(x.rect.xmax(), x.p.y());
209             start.drawTo(end);
210         }
211         draw(x.left);
212         draw(x.right);
213     }
214 
215     public Iterable<Point2D> range(RectHV rect) {
216         // all points that are inside the rectangle (or on the boundary)
217         if (rect == null)
218             throw new IllegalArgumentException("RectHV rect is not illegal!");
219         if (root != null)
220             return range(root, rect);
221         else
222             return new ArrayList<Point2D>();
223     }
224 
225     private ArrayList<Point2D> range(Node x, RectHV rect) {
226         ArrayList<Point2D> points = new ArrayList<Point2D>();
227         if (x.rect.intersects(rect)) {
228             if (rect.contains(x.p))
229                 points.add(x.p);
230             if (x.left != null)
231                 points.addAll(range(x.left, rect));
232             if (x.right != null)
233                 points.addAll(range(x.right, rect));
234         }
235         return points;
236     }
237 
238     public Point2D nearest(Point2D p) {
239         // a nearest neighbor in the set to point p; null if the set is empty
240         if (p == null)
241             throw new IllegalArgumentException("Point2D p is not illegal!");
242         if (root != null)
243             return nearest(root, p, root.p);
244         return null;
245     }
246 
247     /**
248      * 作业提交提示nearest的时间复杂度偏高,导致作业只有98分,我觉得这样写比较清晰明了,就懒得继续优化
249      * @param x
250      * @param p
251      * @param currNearPoint
252      * @return
253      */
254     private Point2D nearest(Node x, Point2D p, Point2D currNearPoint) {
255         if(x.p.equals(p)) return x.p;
256         double currMinDistance = currNearPoint.distanceTo(p);
257         if (Double.compare(x.rect.distanceTo(p), currMinDistance) >= 0)
258             return currNearPoint;
259         else {
260             double distance = x.p.distanceTo(p);
261             if (Double.compare(x.p.distanceTo(p), currMinDistance) == -1) {
262                 currNearPoint = x.p;
263                 currMinDistance = distance;
264             }
265             if (x.left != null)
266                 currNearPoint = nearest(x.left, p, currNearPoint);
267             if (x.right != null)
268                 currNearPoint = nearest(x.right, p, currNearPoint);
269         }
270         return currNearPoint;
271     }
272 
273     public static void main(String[] args) {
274         // unit testing of the methods (optional)
275     }
276 }
原文地址:https://www.cnblogs.com/evasean/p/7367853.html