diff --git a/sort/sort.go2 b/sort/sort.go2 new file mode 100644 index 0000000..43c1ddc --- /dev/null +++ b/sort/sort.go2 @@ -0,0 +1,193 @@ +package sort + +type numeric interface { + type int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64 +} + +// Sort sorts data +func Sort(type T numeric) (data []T) { + n := len(data) + quickSort(T)(data, 0, n, maxDepth(n)) +} + +func quickSort(type T numeric) (data []T, a, b, maxDepth int) { + for b-a > 12 { // Use ShellSort for slices <= 12 elements + if maxDepth == 0 { + heapSort(T)(data, a, b) + return + } + maxDepth-- + mlo, mhi := doPivot(data, a, b) + // Avoiding recursion on the larger subproblem guarantees + // a stack depth of at most lg(b-a). + if mlo-a < b-mhi { + quickSort(T)(data, a, mlo, maxDepth) + a = mhi // i.e., quickSort(data, mhi, b) + } else { + quickSort(T)(data, mhi, b, maxDepth) + b = mlo // i.e., quickSort(data, a, mlo) + } + } + if b-a > 1 { + // Do ShellSort pass with gap 6 + // It could be written in this simplified form cause b-a <= 12 + for i := a + 6; i < b; i++ { + if data[i] < data[i-6] { + data[i], data[i-6] = data[i-6], data[i] + } + } + insertionSort(T)(data, a, b) + } +} + +// Insertion sort +func insertionSort(type T numeric)(data []T, a, b int) { + for i := a + 1; i < b; i++ { + for j := i; j > a && data[j] < data[j-1]; j-- { + data[j], data[j-1] = data[j-1], data[j] + } + } +} + +func heapSort(type T numeric)(data []T, a, b int) { + first := a + lo := 0 + hi := b - a + + // Build heap with greatest element at top. + for i := (hi - 1) / 2; i >= 0; i-- { + siftDown(T)(data, i, hi, first) + } + + // Pop elements, largest first, into end of data. + for i := hi - 1; i >= 0; i-- { + data[first], data[first+i] = data[first+1], data[first] + siftDown(T)(data, lo, i, first) + } +} + +func siftDown(type T numeric)(data []T, lo, hi, first int) { + root := lo + for { + child := 2*root + 1 + if child >= hi { + break + } + if child+1 <= hi && data[first+child] <= data[first+child+1] { + child++ + } + if data[first+root] > data[first+child] { + return + } + data[first+root], data[first+child] = data[first+child], data[first+root] + root = child + } +} + +func medianOfThree(type T numeric)(data []T, m1, m0, m2 int) { + // sort 3 elements + if data[m1] <= data[m0] { + data[m1], data[m0] = data[m0], data[m1] + } + // data[m0] <= data[m1] + if data[m2] <= data[m1] { + data[m2], data[m1] = data[m1], data[m2] + // data[m0] <= data[m2] && data[m1] <= data[m2] + if data[m1] <= data[m0] { + data[m1], data[m0] = data[m0], data[m1] + } + } + // now data[m0] <= data[m1] <= data[m2] +} + +func doPivot(type T numeric)(data []T, lo, hi int) (midlo, midhi int) { + m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow. + if hi-lo > 40 { + // Tukey's ``Ninther,'' median of three medians of three. + s := (hi - lo) / 8 + medianOfThree(T)(data, lo, lo+s, lo+2*s) + medianOfThree(T)(data, m, m-s, m+s) + medianOfThree(T)(data, hi-1, hi-1-s, hi-1-2*s) + } + medianOfThree(T)(data, lo, m, hi-1) + + // Invariants are: + // data[lo] = pivot (set up by ChoosePivot) + // data[lo < i < a] < pivot + // data[a <= i < b] <= pivot + // data[b <= i < c] unexamined + // data[c <= i < hi-1] > pivot + // data[hi-1] >= pivot + pivot := lo + a, c := lo+1, hi-1 + + for ; a < c && data[a] <= data[pivot]; a++ { + } + b := a + for { + for ; b < c && data[b] <= data[pivot]; b++ { + } + for ; b < c && data[c-1] > data[pivot]; c-- { + } + if b >= c { + break + } + // data[b] > pivot; data[c-1] <= pivot + data[b], data[c-1] = data[c-1], data[b] + b++ + c-- + } + // If hi-c<3 then there are duplicates (by property of median of nine). + // Let's be a bit more conservative, and set border to 5. + protect := hi-c < 5 + if !protect && hi-c < (hi-lo)/4 { + // Lets test some points for equality to pivot + dups := 0 + if data[hi-1] == data[pivot] { + + data[c], data[hi-1] = data[hi-1], data[c] + c++ + dups++ + } + if data[b-1] == data[pivot] { + b-- + dups++ + } + // m-lo = (hi-lo)/2 > 6 + // b-lo > (hi-lo)*3/4-1 > 8 + // ==> m < b ==> data[m] <= pivot + if data[m] == data[pivot] { // data[m] = pivot + data[m], data[b-1] = data[b-1], data[m] + b-- + dups++ + } + // if at least 2 points are equal to pivot, assume skewed distribution + protect = dups > 1 + } + if protect { + // Protect against a lot of duplicates + for { + for ; a < b && data[b-1] == data[pivot]; b-- { + } + for ; a < b && data[a] < data[pivot]; a++ { + } + if a >= b { + break + } + data[a], data[b-1] = data[b-1], data[a] + a++ + b-- + } + } + // Swap pivot into middle + data[pivot], data[b-1] = data[b-1], data[pivot] + return b - 1, c +} + +func maxDepth(n int) int { + var depth int + for i := n; i > 0; i >>= 1 { + depth++ + } + return depth * 2 +} diff --git a/sort/sort_test.go2 b/sort/sort_test.go2 new file mode 100644 index 0000000..64eff0f --- /dev/null +++ b/sort/sort_test.go2 @@ -0,0 +1,39 @@ +package sort + +import "testing" + +func equal(type T comparable)(a, b []T) bool { + for idx := range a { + if a[idx] != b[idx] { + return false + } + } + return true +} + +func TestSortInt(t *testing.T) { + given := []int{3,2,1} + want := []int{1,2,3} + Sort(int)(given) + if !equal(int)(given, want) { + t.Errorf("sort failed: got %v want %v", given, want) + } +} + +func TestSortFloat(t *testing.T) { + given := []float64{3.3,2.2,1.1} + want := []float64{1.1,2.2,3.3} + Sort(float64)(given) + if !equal(float64)(given, want) { + t.Errorf("sort failed: got %v want %v", given, want) + } +} + +func TestSortByte(t *testing.T) { + given := []byte{3,2,1} + want := []byte{1,2,3} + Sort(byte)(given) + if !equal(byte)(given, want) { + t.Errorf("sort failed: got %v want %v", given, want) + } +}