N要素の分岐しない odd-even mergesort (GCC+Linux/x86_64用)
(こちらの記事の続き)
(ジェネレータを生成)
% g++ -O2 -Wall gen_sort.cpp -o gen_sort
(8要素向けのソート関数を作成)
% ./gen_sort 0 8 > oem_sort_8.s
% gcc -c oem_sort_8.s
(ちゃんと作られたようだ)
% nm oem_sort_8.o
0000000000000000 T odd_even_sort8
(テストコード書き)
% cat > test.c
#include <assert.h>
void odd_even_sort8(unsigned int* d);
/* Smoke test: sorts one fixed 8-element array with the generated routine
 * and asserts that the result is in non-decreasing order. */
int main() {
int i;
unsigned int s[] = {7,6,5,4,1,2,3,4}; /* sample input; note the duplicated 4 */
odd_even_sort8(s); /* generated routine; rewrites s in place */
for(i = 0; i < 7; ++i) { /* check each of the 7 adjacent pairs */
assert(s[i] <= s[i + 1]);
}
return 0;
}
^D
(テストコードをコンパイル、ソート関数とリンク)
% gcc -Wall test.c oem_sort_8.o
(無事実行)
% ./a.out
%
(ジェネレータの使いかた)
% ./gen_sort
usage: ./gen_sort [0|1] N
0: odd-even sorting, 1: bitonic sorting
N: 4, 8, 16, 32, 64, ...

以下ソースコード。
// gen_sort.cpp
// こんなものをC++で書くのは狂気の沙汰ですね..
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <boost/array.hpp>
#include <boost/bind.hpp>
#include <boost/function.hpp>
#include <boost/utility.hpp>
// Describes one CPU register as used by the x86_64 code generator.
struct Reg {
// Operand name written into the emitted mov/cmp/cmov/movl instructions.
const char* const reg;
// Name written after pushq/popq when the register is saved and restored.
const char* const reg_clob;
// For entries of x86_64::REGS: true means the register is pushed in the
// prologue and popped in the epilogue.
const bool callee_save;
};
// Writes an x86-64 assembly routine (AT&T syntax) that applies a sorting
// network to an array of 32-bit unsigned integers using only cmov, i.e.
// without conditional branches.
//
// The first REGS.size() elements are kept in registers for the whole routine;
// any further elements are accessed in memory through SRC_REG.
class x86_64 : boost::noncopyable {
public:
  // ost receives the generated assembly text; it must outlive this object.
  explicit x86_64(std::ostream& ost) : ost_(ost) {}

  // Generates the complete routine for n elements.
  //
  // algo must be callable in two ways:
  //   algo(n)            -> name of the generated function
  //   algo(n, callback)  -> invokes callback(a, b, direction) once for every
  //                         comparator of the network
  // Returns the number of comparators that were emitted.
  template<typename T>
  int out_asm(std::size_t n, T algo) {
    int count = 0;
    out_prologue(n, algo);
    algo(n, boost::bind(&x86_64::out_cmp_swap, this, _1, _2, _3, boost::ref(count)));
    out_epilogue(n);
    return count;
  }

private:
  static const char* const SRC_REG;
  static const boost::array<Reg, 9> REGS;
  static const boost::array<Reg, 3> TMP_REGS;
  std::ostream& ost_;

  // Writes "\t<op><src>, <dst>\n". op carries its own trailing space.
  // '\n' is used instead of std::endl so the stream is not flushed after
  // every single instruction.
  void emit(const char* op, const std::string& src, const std::string& dst) {
    ost_ << '\t' << op << src << ", " << dst << '\n';
  }

  // Writes "\t<op><operand>\n" for one-operand instructions.
  void emit(const char* op, const std::string& operand) {
    ost_ << '\t' << op << operand << '\n';
  }

  // Number of leading elements that live in REGS for an n-element sort.
  static std::size_t regs_in_use(std::size_t n) {
    return n < REGS.size() ? n : REGS.size();
  }

  // Operand through which element i is accessed: its register if it has one,
  // otherwise its memory location.
  std::string operand_for(std::size_t i) {
    return i < REGS.size() ? std::string(REGS[i].reg) : get_mem_location(i);
  }

  // Callback handed to the network generator: emits one compare-and-swap of
  // elements a and b and increments count.
  //
  // direction == 1 selects cmova (elements are exchanged when a > b, unsigned);
  // any other value selects cmovb (exchanged when a < b).
  void out_cmp_swap(std::size_t a, std::size_t b, int direction, int& count) {
    const std::string lhs = operand_for(a);
    const std::string rhs = operand_for(b);
    const char* const cmov = direction == 1 ? "cmova " : "cmovb ";
    const std::string t0 = TMP_REGS[0].reg;

    if (a < REGS.size() && b < REGS.size()) {
      // Both operands are registers: one scratch register is enough.
      emit("mov ", rhs, t0);
      emit("cmp ", rhs, lhs);
      emit(cmov, lhs, rhs);
      emit(cmov, t0, lhs);
    } else {
      // At least one operand is in memory: load both, exchange them
      // conditionally through a third scratch register, then write both back.
      const std::string t1 = TMP_REGS[1].reg;
      const std::string t2 = TMP_REGS[2].reg;
      emit("mov ", rhs, t0);
      emit("mov ", lhs, t1);
      emit("cmp ", t0, t1);
      emit(cmov, t1, t2);
      emit(cmov, t0, t1);
      emit(cmov, t2, t0);
      emit("mov ", t0, rhs);
      emit("mov ", t1, lhs);
    }
    ++count;
  }

  // Saves every register the routine will overwrite and must preserve: the
  // callee-save entries of REGS that are in use, then the scratch registers.
  // Only TMP_REGS[0] is needed while n <= 8, because the three-scratch
  // sequence is used only for elements beyond REGS.
  void push_clobber_regs(std::size_t n) {
    const std::size_t used = regs_in_use(n);
    for (std::size_t i = 0; i < used; ++i) {
      if (REGS[i].callee_save) {
        emit("pushq ", REGS[i].reg_clob);
      }
    }
    const std::size_t scratch = n > 8 ? TMP_REGS.size() : 1;
    for (std::size_t i = 0; i < scratch; ++i) {
      emit("pushq ", TMP_REGS[i].reg_clob);
    }
  }

  // Restores the registers saved by push_clobber_regs, in exact reverse order.
  void pop_clobber_regs(std::size_t n) {
    const std::size_t scratch = n > 8 ? TMP_REGS.size() : 1;
    for (std::size_t i = scratch; i > 0; --i) {
      emit("popq ", TMP_REGS[i - 1].reg_clob);
    }
    for (std::size_t i = regs_in_use(n); i > 0; --i) {
      if (REGS[i - 1].callee_save) {
        emit("popq ", REGS[i - 1].reg_clob);
      }
    }
  }

  // Loads the leading elements of the array into their registers.
  void load_regs(std::size_t n) {
    const std::size_t used = regs_in_use(n);
    for (std::size_t i = 0; i < used; ++i) {
      emit("movl ", get_mem_location(i), REGS[i].reg);
    }
  }

  // Writes the register-resident elements back to the array.
  void store_regs(std::size_t n) {
    const std::size_t used = regs_in_use(n);
    for (std::size_t i = 0; i < used; ++i) {
      emit("movl ", REGS[i].reg, get_mem_location(i));
    }
  }

  // Emits the section directive, the exported label named algo(n), the
  // register saves and the initial loads.
  template <typename T>
  void out_prologue(std::size_t n, T algo) {
    const std::string name = algo(n);
    ost_ << ".text" << '\n'
         << ".globl " << name << '\n'
         << name << ":" << '\n';
    push_clobber_regs(n);
    load_regs(n);
  }

  // Emits the write-back, the register restores and the return.
  void out_epilogue(std::size_t n) {
    store_regs(n);
    pop_clobber_regs(n);
    ost_ << "\tretq" << '\n';  // no leaveq needed: no stack frame is set up
  }

  // Memory operand for element i: "(SRC_REG)" for i == 0, otherwise
  // "<4*i>(SRC_REG)".
  std::string get_mem_location(std::size_t i) {
    std::ostringstream ret;
    if (i > 0) {
      ret << 4 * i; /* 4 == sizeof(unsigned int) on x86_64 */
    }
    ret << "(" << SRC_REG << ")";
    return ret.str();
  }
};
// Register used as the base of every memory operand built by
// get_mem_location(), i.e. it holds the pointer to the array being sorted
// (presumably the first argument under the System V AMD64 calling convention).
const char* const x86_64::SRC_REG = "%rdi";
// Registers that hold array elements: element i lives in REGS[i] for i < 9.
// Columns: 32-bit operand name, name used with pushq/popq, callee-save flag.
// Only entries flagged true are ever pushed, so reg_clob is unused for the
// others.
const boost::array<Reg, 9> x86_64::REGS
= {{
{ "%eax", "%eax", false},
{ "%esi", "%esi", false},
{ "%edx", "%edx", false},
{ "%ecx", "%ecx", false},
{ "%r8d", "%r8" , false},
{ "%r9d", "%r9" , false},
{ "%r10d", "%r10", false},
{ "%r11d", "%r11", false},
{ "%r12d", "%r12", true}, // the only element register saved and restored
}};
// Scratch registers used inside a compare-and-swap. They are always pushed
// and popped by push_clobber_regs/pop_clobber_regs (only the first when
// n <= 8), independent of the callee-save flag.
const boost::array<Reg, 3> x86_64::TMP_REGS
= {{
{ "%r13d", "%r13", true},
{ "%r14d", "%r14", true},
{ "%r15d", "%r15", true},
}};
// Generator for Batcher's odd-even merge sorting network.
//
// The class holds no data: it only enumerates the comparators of the network
// as index pairs, so no backing array is required and N is not limited by a
// buffer size.
class odd_even_mergesort {
  // Reference: http://www.inf.fh-flensburg.de/lang/algorithmen/sortieren/networks/oemen.htm
public:
  // Calls out(a, b, 1) once for every comparator of the n-element network, in
  // an order that sorts ascending when each comparator leaves the smaller
  // value at index a.
  //
  // n must be a power of two (checked with assert). out may be any callable
  // accepting (std::size_t, std::size_t, int), e.g. a boost::bind result or a
  // boost::function.
  template <typename F>
  void operator() (std::size_t n, F out) {
    assert(n > 0 && (n & (n - 1)) == 0);
    mergesort(0, n, out);
  }

  // Returns the symbol name of the generated routine: "odd_even_sort<n>".
  std::string operator() (std::size_t n) {
    std::stringstream ss;
    ss << "odd_even_sort" << n;
    return ss.str();
  }

private:
  // Merges the subsequence base, base+skip, base+2*skip, ... of n elements,
  // whose two halves are already sorted. The callback is passed by reference
  // so that it is not copied on every recursive call.
  template <typename F>
  void merge(std::size_t base, std::size_t n, std::size_t skip, F& out) {
    if (n > 2) {
      merge(base, n / 2, skip * 2, out);         // even-indexed subsequence
      merge(base + skip, n / 2, skip * 2, out);  // odd-indexed subsequence
      // Final pass: compare elements (1,2), (3,4), ..., (n-3,n-2).
      for (std::size_t i = 1; i + 2 < n; i += 2) {
        out(base + i * skip, base + (i + 1) * skip, 1);
      }
    } else {
      out(base, base + skip, 1);
    }
  }

  // Sorts the n consecutive elements starting at index base: sort both
  // halves, then merge them.
  template <typename F>
  void mergesort(std::size_t base, std::size_t n, F& out) {
    if (n > 1) {
      mergesort(base, n / 2, out);
      mergesort(base + n / 2, n / 2, out);
      merge(base, n, 1, out);
    }
  }
};
// Generator for the bitonic sorting network (bonus alternative to the
// odd-even merge network).
class bitonic_sort {
  // Reference: http://jyoken.net/2005/kenpatsu/enari_oraf/
public:
  // Calls out(a, b, direction) once for every comparator of the n-element
  // network. direction == 1 means "leave the smaller value at a", -1 means
  // "leave the larger value at a". The last stage uses direction 1 only, so
  // the final result is ascending.
  //
  // n must be a power of two (checked with assert). out may be any callable
  // accepting (std::size_t, std::size_t, int).
  template <typename F>
  void operator() (std::size_t n, F out) {
    assert(n > 0 && (n & (n - 1)) == 0);
    // bs is the size of the blocks being sorted in this stage.
    for (std::size_t bs = 2; bs <= n; bs *= 2) {
      int direction = 1;
      for (std::size_t block = 0; block < n / bs; ++block) {
        const std::size_t base = bs * block;
        // Bitonic merge of one block: halve the comparison distance each pass.
        for (std::size_t interval = bs / 2; interval >= 1; interval /= 2) {
          // Each pass pairs the lower and upper half of every group of
          // 2 * interval elements. Walking the groups explicitly emits
          // exactly bs / 2 comparators per pass, with no redundant
          // comparators across group boundaries.
          for (std::size_t group = 0; group < bs; group += 2 * interval) {
            for (std::size_t k = 0; k < interval; ++k) {
              const std::size_t lo = base + group + k;
              out(lo, lo + interval, direction);
            }
          }
        }
        // Neighbouring blocks are sorted in opposite directions so that each
        // pair forms a bitonic sequence for the next stage.
        direction *= -1;
      }
    }
  }

  // Returns the symbol name of the generated routine: "bitonic_sort<n>".
  std::string operator() (std::size_t n) {
    std::stringstream ss;
    ss << "bitonic_sort" << n;
    return ss.str();
  }
};
// Command-line entry point.
//
// Usage: gen_sort [0|1] N
//   0 selects the odd-even merge network, 1 the bitonic network.
//   N is the number of elements and must be a power of two >= 4.
// Writes the generated assembly, followed by a "# ..." summary line, to
// standard output. Returns 1 after printing the usage text when the arguments
// are missing or invalid, 0 otherwise.
int main(int argc, char** argv) {
  const char* const prog = (argc > 0 && argv[0]) ? argv[0] : "gen_sort";
  const bool have_args = argc > 2;
  const int type = have_args ? std::atoi(argv[1]) : 0;
  // Parse N as a signed value and validate it before converting to size_t:
  // a negative number would otherwise wrap to a huge unsigned value, and a
  // 32-bit popcount on the truncated value could still report "one bit set".
  const long requested = have_args ? std::strtol(argv[2], NULL, 10) : 0;
  const bool size_ok = requested >= 4 && (requested & (requested - 1)) == 0;

  if (!have_args || (type != 0 && type != 1) || !size_ok) {
    std::cout << "usage: " << prog << " [0|1] N" << '\n'
              << "0: odd-even sorting, 1: bitonic sorting" << '\n'
              << "N: 4, 8, 16, 32, 64, ..." << '\n';
    return 1;
  }

  const std::size_t n = static_cast<std::size_t>(requested);
  int count = 0;
  x86_64 x(std::cout);
  switch (type) {
  case 0:
    count = x.out_asm(n, odd_even_mergesort());
    break;
  case 1:
    count = x.out_asm(n, bitonic_sort());
    break;
  }
  // Trailing assembler comment recording the size of the network.
  std::cout << "# " << n << " elements, " << count << " comparators" << '\n';
  return 0;
}