《算法》第二章部分程序 part 2

▶ 书中第二章的部分程序，加上自己补充的代码，包括若干种归并排序，以及利用归并排序计算数组的逆序数。

 ● 归并排序

  1 package package01;
  2 
  3 import java.util.Comparator;
  4 import edu.princeton.cs.algs4.StdIn;
  5 import edu.princeton.cs.algs4.StdOut;
  6 
  7 public class class01
  8 {
  9     private class01() {}
 10 
 11     private static void merge(Comparable[] a, Comparable[] aux, int lo, int mid, int hi)    // 归并两个排好序的子数组
 12     {
 13         for (int k = lo; k <= hi; k++)
 14             aux[k] = a[k];
 15 
 16         int i = lo, j = mid + 1;
 17         for (int k = lo; k <= hi; k++)
 18         {
 19             if (i > mid)                    // 后段剩余
 20                 a[k] = aux[j++];
 21             else if (j > hi)                // 前段剩余
 22                 a[k] = aux[i++];
 23             else if (less(aux[j], aux[i]))  // 比较
 24                 a[k] = aux[j++];
 25             else
 26                 a[k] = aux[i++];
 27         }
 28     }
 29 
 30     private static void sortTDKernel(Comparable[] a, Comparable[] aux, int lo, int hi)      // 排序递归内核
 31     {
 32         if (hi <= lo)
 33             return;
 34         int mid = lo + (hi - lo) / 2;
 35         sortTDKernel(a, aux, lo, mid);
 36         sortTDKernel(a, aux, mid + 1, hi);
 37         merge(a, aux, lo, mid, hi);
 38     }
 39 
 40     public static void sortTD(Comparable[] a)                                               // 自顶向下的归并排序
 41     {
 42         Comparable[] aux = new Comparable[a.length];    // 统一分配临时内存
 43         sortTDKernel(a, aux, 0, a.length - 1);
 44     }
 45 
 46     private static void indexMerge(Comparable[] a, int[] index, int[] aux, int lo, int mid, int hi)  // 间接排序的归并
 47     {
 48         for (int k = lo; k <= hi; k++)
 49             aux[k] = index[k];
 50 
 51         int i = lo, j = mid + 1;
 52         for (int k = lo; k <= hi; k++)
 53         {
 54             if (i > mid)
 55                 index[k] = aux[j++];
 56             else if (j > hi)
 57                 index[k] = aux[i++];
 58             else if (less(a[aux[j]], a[aux[i]]))
 59                 index[k] = aux[j++];
 60             else
 61                 index[k] = aux[i++];
 62         }
 63     }
 64 
 65     private static void indexSortTDKernel(Comparable[] a, int[] index, int[] aux, int lo, int hi)   // 间接排序递归内核
 66     {
 67         if (hi <= lo)
 68             return;
 69         int mid = lo + (hi - lo) / 2;
 70         indexSortTDKernel(a, index, aux, lo, mid);
 71         indexSortTDKernel(a, index, aux, mid + 1, hi);
 72         indexMerge(a, index, aux, lo, mid, hi);
 73     }
 74 
 75     public static int[] indexSortTD(Comparable[] a)                                                 // 自顶向下的间接归并排序
 76     {
 77         int n = a.length;
 78         int[] aux = new int[n];
 79         int[] index = new int[n];
 80         for (int i = 0; i < n; i++)
 81             index[i] = i;
 82 
 83         indexSortTDKernel(a, index, aux, 0, n - 1);
 84         return index;
 85     }
 86 
 87     public static void sortBU(Comparable[] a)                   // 自底向上的归并排序,不需要递归,合并子数组的函数与前面相同
 88     {
 89         int n = a.length;
 90         Comparable[] aux = new Comparable[n];
 91         for (int len = 1; len < n; len *= 2)
 92         {
 93             for (int lo = 0; lo < n - len; lo += len + len)
 94             {
 95                 int mid = lo + len - 1;
 96                 int hi = Math.min(lo + len + len - 1, n - 1);
 97                 merge(a, aux, lo, mid, hi);
 98             }
 99         }
100     }
101 
102     // 改良版本
103     private static final int CUTOFF = 7;    // 小于该尺寸的数组使用插入排序
104 
105     private static void merge2(Comparable[] src, Comparable[] dst, int lo, int mid, int hi) // 区分原数组和目标数组,减少拷贝
106     {
107         int i = lo, j = mid + 1;
108         for (int k = lo; k <= hi; k++)
109         {
110             if (i > mid)
111                 dst[k] = src[j++];
112             else if (j > hi)
113                 dst[k] = src[i++];
114             else if (less(src[j], src[i]))
115                 dst[k] = src[j++];
116             else
117                 dst[k] = src[i++];
118         }
119     }
120 
121     private static void sort2TDKernel(Comparable[] src, Comparable[] dst, int lo, int hi)   
122     {
123         if (hi <= lo + CUTOFF)                                  // 数据较少时使用插入排序
124         {
125             insertionSort(dst, lo, hi);
126             return;
127         }
128         int mid = lo + (hi - lo) / 2;
129         sort2TDKernel(dst, src, lo, mid);
130         sort2TDKernel(dst, src, mid + 1, hi);
131 
132         if (!less(src[mid + 1], src[mid]))                      // src[mid+1] >= src[mid],不用归并了
133             System.arraycopy(src, lo, dst, lo, hi - lo + 1);    // 数组拷贝,快于 for (int i = lo; i <= hi; i++) dst[i] = src[i];
134         else
135             merge2(src, dst, lo, mid, hi);
136     }
137 
138     public static void sort2TD(Comparable[] a)
139     {
140         Comparable[] aux = a.clone();
141         sort2TDKernel(aux, a, 0, a.length - 1);
142     }
143 
144     private static void insertionSort(Comparable[] a, int lo, int hi)
145     {
146         for (int i = lo; i <= hi; i++)
147         {
148             for (int j = i; j > lo && less(a[j], a[j - 1]); j--)
149                 exch(a, j, j - 1);
150         }
151     }
152 
153     private static void exch(Comparable[] a, int i, int j)      // 插入排序用到的交换
154     {
155         Comparable swap = a[i];
156         a[i] = a[j];
157         a[j] = swap;
158     }
159 
160     private static void merge2(Object[] src, Object[] dst, int lo, int mid, int hi, Comparator comparator)  // 自定义类型的版本(同上 5 个函数)
161     {
162         int i = lo, j = mid + 1;
163         for (int k = lo; k <= hi; k++)
164         {
165             if (i > mid)
166                 dst[k] = src[j++];
167             else if (j > hi)
168                 dst[k] = src[i++];
169             else if (less(src[j], src[i], comparator))
170                 dst[k] = src[j++];
171             else
172                 dst[k] = src[i++];
173         }
174     }
175 
176     private static void sort2TDKernel(Object[] src, Object[] dst, int lo, int hi, Comparator comparator)
177     {
178         if (hi <= lo + CUTOFF)
179         {
180             insertionSort(dst, lo, hi, comparator);
181             return;
182         }
183 
184         int mid = lo + (hi - lo) / 2;
185         sort2TDKernel(dst, src, lo, mid, comparator);
186         sort2TDKernel(dst, src, mid + 1, hi, comparator);
187 
188         if (!less(src[mid + 1], src[mid], comparator))
189             System.arraycopy(src, lo, dst, lo, hi - lo + 1);
190         else
191             merge2(src, dst, lo, mid, hi, comparator);
192     }
193 
194     public static void sort2TD(Object[] a, Comparator comparator)
195     {
196         Object[] aux = a.clone();
197         sort2TDKernel(aux, a, 0, a.length - 1, comparator);
198     }
199 
200     private static void insertionSort(Object[] a, int lo, int hi, Comparator comparator)
201     {
202         for (int i = lo; i <= hi; i++)
203         {
204             for (int j = i; j > lo && less(a[j], a[j - 1], comparator); j--)
205                 exch(a, j, j - 1);
206         }
207     }
208 
209     private static void exch(Object[] a, int i, int j)
210     {
211         Object swap = a[i];
212         a[i] = a[j];
213         a[j] = swap;
214     }
215 
216     private static boolean less(Comparable a, Comparable b)                 // 各排序都用到的比较函数
217     {
218         return a.compareTo(b) < 0;
219     }
220 
221     private static boolean less(Object a, Object b, Comparator comparator)  // 自定义类型的比较函数
222     {
223         return comparator.compare(a, b) < 0;
224     }
225 
226     private static void show(Comparable[] a)
227     {
228         for (int i = 0; i < a.length; i++)
229             StdOut.println(a[i]);
230     }
231 
232     public static void main(String[] args)
233     {
234         String[] a = StdIn.readAllStrings();
235         //int[] index = class01.indexSortTD(a);
236 
237         class01.sortTD(a);
238         //class01.sortBU(a);
239         //class01.sort2TD(a);
240         //for (int i = 0; i<a.length; i++)
241         //    StdOut.println(index[i]);
242         System.out.printf("
Finish!
");
243     }
244 }

● 利用归并排序计算数组的逆序数，注释重点说明与归并排序不同的地方

  1 package package01;
  2 
  3 import edu.princeton.cs.algs4.In;
  4 import edu.princeton.cs.algs4.StdOut;
  5 
  6 public class class01
  7 {
  8     private class01() {}
  9 
 10     private static long merge(int[] a, int[] aux, int lo, int mid, int hi)  // 限定输入为 int 数组
 11     {
 12         long inversions = 0;
 13 
 14         for (int k = lo; k <= hi; k++)
 15             aux[k] = a[k];
 16 
 17         int i = lo, j = mid + 1;
 18         for (int k = lo; k <= hi; k++)
 19         {
 20             if (i > mid)
 21                 a[k] = aux[j++];
 22             else if (j > hi)
 23                 a[k] = aux[i++];
 24             else if (aux[j] < aux[i])                // 算术比较
 25             {
 26                 a[k] = aux[j++];
 27                 inversions += (mid - i + 1);        // 多了一步计算
 28             }
 29             else
 30                 a[k] = aux[i++];
 31         }
 32         return inversions;                            // 返回逆序数
 33     }
 34 
 35     private static long count(int[] a, int[] b, int[] aux, int lo, int hi)  // 部分计数函数
 36     {
 37         long inversions = 0;
 38         if (hi <= lo)
 39             return 0;
 40         int mid = lo + (hi - lo) / 2;
 41         inversions += count(a, b, aux, lo, mid);        // 分治和归并的部分补上计算
 42         inversions += count(a, b, aux, mid + 1, hi);
 43         inversions += merge(b, aux, lo, mid, hi);
 44         return inversions;
 45     }
 46 
 47     public static long count(int[] a)                   // 可调用的计数函数
 48     {
 49         int[] b = new int[a.length];
 50         int[] aux = new int[a.length];
 51         for (int i = 0; i < a.length; i++)
 52             b[i] = a[i];
 53         return count(a, b, aux, 0, a.length - 1);
 54     }
 55 
 56     private static long brute(int[] a, int lo, int hi)  // 枚举方法计算逆序数
 57     {
 58         long inversions = 0;
 59         for (int i = lo; i <= hi; i++)
 60         {
 61             for (int j = i + 1; j <= hi; j++)
 62                 if (a[j] < a[i])
 63                     inversions++;
 64         }
 65         return inversions;
 66     }
 67 
 68     // 自定义类型版本
 69     private static <Key extends Comparable<Key>> long merge(Key[] a, Key[] aux, int lo, int mid, int hi)
 70     {
 71         long inversions = 0;
 72 
 73         for (int k = lo; k <= hi; k++)
 74             aux[k] = a[k];
 75 
 76         int i = lo, j = mid + 1;
 77         for (int k = lo; k <= hi; k++)
 78         {
 79             if (i > mid)
 80                 a[k] = aux[j++];
 81             else if (j > hi)
 82                 a[k] = aux[i++];
 83             else if (less(aux[j], aux[i]))            // 比较方法改回去了
 84             {
 85                 a[k] = aux[j++];
 86                 inversions += (mid - i + 1);
 87             }
 88             else
 89                 a[k] = aux[i++];
 90         }
 91         return inversions;
 92     }
 93 
 94     private static <Key extends Comparable<Key>> boolean less(Key v, Key w)
 95     {
 96         return (v.compareTo(w) < 0);
 97     }
 98 
 99     private static <Key extends Comparable<Key>> long count(Key[] a, Key[] b, Key[] aux, int lo, int hi)
100     {
101         long inversions = 0;
102         if (hi <= lo)
103             return 0;
104         int mid = lo + (hi - lo) / 2;
105         inversions += count(a, b, aux, lo, mid);
106         inversions += count(a, b, aux, mid + 1, hi);
107         inversions += merge(b, aux, lo, mid, hi);
108         return inversions;
109     }
110 
111     public static <Key extends Comparable<Key>> long count(Key[] a)
112     {
113         Key[] b = a.clone();
114         Key[] aux = a.clone();
115         return count(a, b, aux, 0, a.length - 1);        
116     }
117 
118     private static <Key extends Comparable<Key>> long brute(Key[] a, int lo, int hi)
119     {
120         long inversions = 0;
121         for (int i = lo; i <= hi; i++)
122         {
123             for (int j = i + 1; j <= hi; j++)
124                 if (less(a[j], a[i]))
125                     inversions++;
126         }
127         return inversions;
128     }
129 
130     public static void main(String[] args)  // 使用文件名而不是重定向来作为输入
131     {
132         In in = new In(args[0]);
133         int[] a = in.readAllInts();
134         int n = a.length;
135         int[] b = new int[n];
136         for (int i = 0; i<n; i++)
137             b[i] = a[i];
138 
139         StdOut.println(class01.count(a));
140         StdOut.println(class01.count(b));
141     }
142 }
原文地址:https://www.cnblogs.com/cuancuancuanhao/p/9753865.html