N要素の分岐しない odd-even mergesort (GCC+Linux/x86_64用)


(こちらの記事の続き)

(ジェネレータをビルド)
% g++ -O2 -Wall gen_sort.cpp -o gen_sort

(8要素向けのソート関数を作成)
% ./gen_sort 0 8 > oem_sort_8.s
% gcc -c oem_sort_8.s

(ちゃんと作られたようだ)
% nm oem_sort_8.o
0000000000000000 T odd_even_sort8

(テストコードを書く)
% cat > test.c
#include <assert.h>
void odd_even_sort8(unsigned int* d);

/* Checks the generated 8-element sort:
   - fixed inputs must come out exactly as expected (so the values are
     permuted, not merely ordered), including values with the top bit set,
     which must be ordered as unsigned;
   - 0-1 principle: a comparator network sorts every input iff it sorts all
     2^8 inputs consisting of 0s and 1s. */
int main() {
  unsigned int s[] = {7,6,5,4,1,2,3,4};
  const unsigned int expected[] = {1,2,3,4,4,5,6,7};
  unsigned int big[] = {0xFFFFFFFFu, 0u, 0x80000000u, 1u, 0x7FFFFFFFu, 2u, 3u, 4u};
  const unsigned int expected_big[] = {0u, 1u, 2u, 3u, 4u, 0x7FFFFFFFu, 0x80000000u, 0xFFFFFFFFu};
  int i, m;

  odd_even_sort8(s);
  odd_even_sort8(big);
  for(i = 0; i < 8; ++i) {
    assert(s[i] == expected[i]);
    assert(big[i] == expected_big[i]);
  }

  for(m = 0; m < 256; ++m) {
    unsigned int d[8];
    int ones = 0;
    for(i = 0; i < 8; ++i) {
      d[i] = (m >> i) & 1;
      ones += (int)d[i];
    }
    odd_even_sort8(d);
    /* sorted ascending: all 0s first, then exactly `ones` 1s */
    for(i = 0; i < 8; ++i) {
      assert(d[i] == (i >= 8 - ones ? 1u : 0u));
    }
  }
  return 0;
}
^D 

(テストコードをコンパイル、ソート関数とリンク)
% gcc -Wall test.c oem_sort_8.o

(無事実行。assert が失敗しなければ何も出力されない)
% ./a.out
%

(ジェネレータの使いかた)
% ./gen_sort
usage: ./gen_sort [0|1] N
0: odd-even sorting, 1: bitonic sorting
N: 4, 8, 16, 32, 64, ...

以下ソースコード

// gen_sort.cpp
// こんなものをC++で書くのは狂気の沙汰ですね..

#include <cassert>
#include <cctype>
#include <cerrno>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>

#include <boost/bind.hpp>
#include <boost/array.hpp>
#include <boost/utility.hpp>
#include <boost/function.hpp>

/// One general-purpose register as seen by the code generator.
struct Reg {
  char const* const reg;          // operand name for the 32-bit element value, e.g. "%r8d"
  char const* const reg_clob;     // name handed to pushq/popq when the register is saved
  bool const        callee_save;  // true if the generated function must save/restore it
};

class x86_64 : boost::noncopyable {
public:
  explicit x86_64(std::ostream& ost) : ost_(ost) {}

  template<typename T>
  int out_asm(std::size_t n, T algo) {
    int count = 0;

    out_prologue(n, algo);
    algo(n, boost::bind(&x86_64::out_cmp_swap, this, _1, _2, _3, boost::ref(count)));
    out_epilogue(n);

    return count;
  }
  
private:

  static const char* const SRC_REG;
  static const boost::array<Reg, 9> REGS; 
  static const boost::array<Reg, 3> TMP_REGS; 

  std::ostream& ost_;

  // callback
  void out_cmp_swap(std::size_t a, std::size_t b, int direction, int& count) {
    
    std::string mem1 = a < REGS.size() ? REGS[a].reg : get_mem_location(a);
    std::string mem2 = b < REGS.size() ? REGS[b].reg : get_mem_location(b);
    const char* const cc = direction == 1 ? "a " : "b ";
    
    if (a < REGS.size() && b < REGS.size()) {
      ost_ << "\tmov " << mem2 << ", " << TMP_REGS[0].reg << std::endl;
      ost_ << "\tcmp " << mem2 << ", " << mem1 << std::endl;
      ost_ << "\tcmov" << cc <<  mem1 << ", " << mem2 << std::endl;
      ost_ << "\tcmov" << cc <<  TMP_REGS[0].reg << ", " << mem1 << std::endl;
    } else {
      ost_ << "\tmov " << mem2 << ", " << TMP_REGS[0].reg << std::endl;
      ost_ << "\tmov " << mem1 << ", " << TMP_REGS[1].reg << std::endl;
      ost_ << "\tcmp " << TMP_REGS[0].reg << ", " << TMP_REGS[1].reg << std::endl;
      ost_ << "\tcmov" << cc << TMP_REGS[1].reg << ", " << TMP_REGS[2].reg << std::endl;
      ost_ << "\tcmov" << cc << TMP_REGS[0].reg << ", " << TMP_REGS[1].reg << std::endl;
      ost_ << "\tcmov" << cc << TMP_REGS[2].reg << ", " << TMP_REGS[0].reg << std::endl;
      ost_ << "\tmov " << TMP_REGS[0].reg << ", " << mem2 << std::endl;
      ost_ << "\tmov " << TMP_REGS[1].reg << ", " << mem1 << std::endl;
    }

    ++count;
  }

  void push_clobber_regs(std::size_t n) {
    for(size_t i = 0; i < n; ++i) {
      if (i < REGS.size() && REGS[i].callee_save) {
	ost_ << "\tpushq " << REGS[i].reg_clob << std::endl;
      }
    }
    if (n > 8) {
      for(size_t i = 0; i < TMP_REGS.size() ; ++i) {
	ost_ << "\tpushq " << TMP_REGS[i].reg_clob << std::endl;
      }
    } else {
      ost_ << "\tpushq " << TMP_REGS[0].reg_clob << std::endl;
    }

    return;
  }

  void pop_clobber_regs(std::size_t n) {
    if (n > 8) {
      for(size_t i = TMP_REGS.size(); i > 0; --i) { 
	ost_ << "\tpopq " << TMP_REGS[i - 1].reg_clob << std::endl;
      }
    } else {
      ost_ << "\tpopq " << TMP_REGS[0].reg_clob << std::endl;
    }
   
    for(size_t i = n; i > 0; --i) {
      if (i - 1 < REGS.size() && REGS[i - 1].callee_save) {
	ost_ << "\tpopq " << REGS[i - 1].reg_clob << std::endl;
      }
    }

    return;
  }
  
  void load_regs(std::size_t n) {
    for(size_t i = 0; i < n; ++i) {
      if (i < REGS.size()) {
	ost_ << "\tmovl " << get_mem_location(i) << ", " << REGS[i].reg << std::endl;
      }
    }
  }
  
  void store_regs(std::size_t n) {
    for(size_t i = 0; i < n; ++i) {
      if (i < REGS.size()) {
	ost_ << "\tmovl "  << REGS[i].reg << ", " << get_mem_location(i) << std::endl;
      }
    }
  }

  template <typename T>
  void out_prologue(std::size_t n, T algo) {
    ost_ << ".text" << std::endl
	 << ".globl " << algo(n) << std::endl
	 << algo(n) << ":" << std::endl;
    push_clobber_regs(n);
    load_regs(n);
  }
  
  void out_epilogue(std::size_t n) {
    store_regs(n);
    pop_clobber_regs(n);
    ost_ << "\tretq" << std::endl; // leaveq 不要
  }
  
  std::string get_mem_location(std::size_t i) {
    std::ostringstream ret;
    if (i > 0) {
      ret << 4 * i; /* 4 == sizeof(unsigned int) on x86_64 */
    }
    ret << "(" << SRC_REG << ")";
    return ret.str(); 
  }
};

// Base register of every memory operand: the array pointer argument.
const char* const x86_64::SRC_REG = "%rdi";

// Registers that hold elements 0..8 for the whole generated function.
// Columns: operand name for the 32-bit element, name given to pushq/popq
// (only consulted when the last column is true), and whether the generated
// function saves/restores the register.
const boost::array<Reg, 9> x86_64::REGS = {{
  // used without being saved
  { "%eax",  "%eax", false },
  { "%esi",  "%esi", false },
  { "%edx",  "%edx", false },
  { "%ecx",  "%ecx", false },
  { "%r8d",  "%r8",  false },
  { "%r9d",  "%r9",  false },
  { "%r10d", "%r10", false },
  { "%r11d", "%r11", false },
  // saved/restored around the body
  { "%r12d", "%r12", true  }
}};

// Scratch registers used by the compare-exchange sequences; the generated
// function pushes/pops the ones it uses.
const boost::array<Reg, 3> x86_64::TMP_REGS = {{
  { "%r13d", "%r13", true },
  { "%r14d", "%r14", true },
  { "%r15d", "%r15", true }
}};

class odd_even_mergesort {
// Reference: http://www.inf.fh-flensburg.de/lang/algorithmen/sortieren/networks/oemen.htm
public:
  /// Emits every comparator of Batcher's odd-even merge sort for n elements.
  ///
  /// Calls `out(a, b, 1)` once per comparator with element indices a < b;
  /// the direction is always 1 (order the pair ascending). Indices are
  /// computed directly, so there is no upper limit on n.
  /// @param n   number of elements; must be a power of two.
  /// @param out any callable accepting (std::size_t, std::size_t, int).
  template <typename Out>
  void operator() (std::size_t n, Out out) {
    // Power-of-two test on the full width of n (popcount would truncate).
    assert(n > 0 && (n & (n - 1)) == 0);
    mergesort(0, n, out);
  }

  /// Symbol name of the generated function, e.g. "odd_even_sort8".
  std::string operator() (std::size_t n) {
    std::stringstream ss;
    ss << "odd_even_sort" << n;
    return ss.str();
  }

private:
  // Merges the two sorted halves of the strided subsequence
  // lo, lo + skip, ..., lo + (n - 1) * skip  (n elements).
  template <typename Out>
  static void merge(std::size_t lo, std::size_t n, std::size_t skip, Out& out) {
    if (n > 2) {
      merge(lo,        n / 2, skip * 2, out);  // even-position elements
      merge(lo + skip, n / 2, skip * 2, out);  // odd-position elements
      // Clean-up comparators on positions (1,2), (3,4), ..., (n-3,n-2).
      for (std::size_t i = 1; i + 2 < n; i += 2) {
        out(lo + i * skip, lo + (i + 1) * skip, 1);
      }
    } else {
      out(lo, lo + skip, 1);
    }
  }

  // Sorts the contiguous index range [lo, lo + n).
  template <typename Out>
  static void mergesort(std::size_t lo, std::size_t n, Out& out) {
    if (n > 1) {
      mergesort(lo,         n / 2, out);
      mergesort(lo + n / 2, n / 2, out);
      merge(lo, n, 1, out);
    }
  }
};

class bitonic_sort {
// Bonus network (reference: http://jyoken.net/2005/kenpatsu/enari_oraf/)
public:
  /// Emits every comparator of a bitonic sorting network for n elements.
  ///
  /// Calls `out(a, b, direction)` with a < b; direction 1 orders the pair
  /// ascending, -1 descending. Exactly (n/2)*log2(n)*(log2(n)+1)/2
  /// comparators are produced (24 for n = 8), with no redundant ones.
  /// @param n   number of elements; must be a power of two.
  /// @param out any callable accepting (std::size_t, std::size_t, int).
  template <typename Out>
  void operator() (std::size_t n, Out out) {
    // Power-of-two test on the full width of n (popcount would truncate).
    assert(n > 0 && (n & (n - 1)) == 0);
    // Stage bs: sort each block of bs elements, alternating ascending and
    // descending so neighbouring blocks form bitonic runs for the next stage.
    for (std::size_t bs = 2; bs <= n; bs *= 2) {
      int direction = 1;
      for (std::size_t base = 0; base < n; base += bs) {
        // Bitonic merge of [base, base + bs): for each interval, compare
        // k with k + interval inside every group of 2 * interval elements.
        // Pairs never cross group boundaries.
        for (std::size_t interval = bs / 2; interval >= 1; interval /= 2) {
          for (std::size_t group = base; group < base + bs; group += 2 * interval) {
            for (std::size_t k = group; k < group + interval; ++k) {
              out(k, k + interval, direction);
            }
          }
        }
        direction *= -1;
      }
    }
  }

  /// Symbol name of the generated function, e.g. "bitonic_sort8".
  std::string operator() (std::size_t n) {
    std::stringstream ss;
    ss << "bitonic_sort" << n;
    return ss.str();
  }
};

// Parses a non-negative decimal integer into `value`. Rejects empty input,
// signs, leading whitespace, trailing characters and out-of-range values,
// all of which std::atoi would silently accept or misinterpret.
static bool parse_decimal(const char* s, unsigned long& value) {
  if (s == 0 || !std::isdigit(static_cast<unsigned char>(s[0]))) {
    return false;
  }
  char* end = 0;
  errno = 0;
  value = std::strtoul(s, &end, 10);
  return errno == 0 && *end == '\0';
}

// Usage: gen_sort TYPE N
//   TYPE 0: odd-even merge sort network, 1: bitonic sort network.
//   N:      number of elements; a power of two, at least 4.
// Writes the assembly to stdout followed by a "# ..." comment line with the
// comparator count. On bad arguments prints usage to stderr and returns 1.
int main(int argc, char** argv) {
  unsigned long type = 0;
  unsigned long n = 0;
  const bool valid = argc > 2
      && parse_decimal(argv[1], type) && (type == 0 || type == 1)
      && parse_decimal(argv[2], n)
      && n >= 4 && (n & (n - 1)) == 0;  // power of two, tested on all bits
  if (!valid) {
    // stderr, so the message never ends up in a redirected .s file.
    std::cerr << "usage: " << (argc > 0 ? argv[0] : "gen_sort") << " [0|1] N" << std::endl
              << "0: odd-even sorting, 1: bitonic sorting" << std::endl
              << "N: 4, 8, 16, 32, 64, ..." << std::endl;
    return 1;
  }

  int count = 0;

  x86_64 x(std::cout);
  switch (type) {
  case 0:
    count = x.out_asm(n, odd_even_mergesort());
    break;
  case 1:
    count = x.out_asm(n, bitonic_sort());
    break;
  }

  std::cout << "# " << n << " elements, " << count << " comparators" << std::endl;
  return 0;
}