Skip to content

support conversion from RowDiff<Multi-BRWT> to RowDiff<RowFlat> #520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions metagraph/src/annotation/annotation_converters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,31 @@ void convert_to_row_diff<RowDiffRowFlatAnnotator>(
logger->trace("Annotation converted");
}


/**
 * Converts a RowDiff<BRWT> annotation to RowDiff<RowFlat> and writes it to disk.
 *
 * The output file name is `outfbase` with the RowDiffRowFlatAnnotator extension
 * applied. The label encoder, the "v2.0" marker, and the `anchor()` and
 * `fork_succ()` vectors of the row-diff matrix are written first; the diff
 * matrix itself is then streamed row by row through RowFlat<>::serialize()
 * into the same file.
 *
 * @param anno      annotation to convert
 * @param outfbase  base name of the output file
 * @throws std::ofstream::failure if the output file cannot be opened, or if
 *         writing the leading part of the file fails
 */
template <>
void convert_to_row_diff<RowDiffRowFlatAnnotator>(const RowDiffBRWTAnnotator &anno,
                                                  const std::string &outfbase) {
    const auto &fname = utils::make_suffix(outfbase, RowDiffRowFlatAnnotator::kExtension);
    std::ofstream out = utils::open_new_ofstream(fname);
    if (!out.good())
        throw std::ofstream::failure("Can't write to " + fname);

    const auto &row_diff = anno.get_matrix();

    anno.get_label_encoder().serialize(out);

    // serialize RowDiff<RowFlat<>>
    out.write("v2.0", 4);
    row_diff.anchor().serialize(out);
    row_diff.fork_succ().serialize(out);
    out.close();

    // close() flushes the buffer and sets failbit on error, so a single check
    // here covers every write above. Without it, the matrix below would be
    // written to a file whose leading part is truncated or missing.
    if (!out)
        throw std::ofstream::failure("Error while writing to " + fname);

    const auto &diffs = row_diff.diffs();

    // NOTE(review): the trailing `true` presumably makes RowFlat<>::serialize
    // append to the file written above -- confirm against its declaration.
    RowFlat<>::serialize([&](auto callback) { diffs.call_rows(callback); },
                         diffs.num_columns(),
                         diffs.num_rows(),
                         diffs.num_relations(),
                         fname, true);
    logger->trace("Annotation converted");
}

template <>
void convert_to_row_diff<RowDiffRowSparseAnnotator>(
const std::vector<std::string> &files,
Expand Down
4 changes: 4 additions & 0 deletions metagraph/src/annotation/annotation_converters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ void convert_to_row_diff(const std::vector<std::string> &files,
size_t num_threads,
size_t mem_bytes);

// Converts the RowDiff<BRWT> annotation `anno` to the row-diff annotation type
// `RowDiffAnnotator` and writes the result to a file derived from `outfbase`.
// NOTE(review): the only specialization visible in this change is the one for
// RowDiffRowFlatAnnotator -- confirm before instantiating with other types.
template <class RowDiffAnnotator>
void convert_to_row_diff(const RowDiffBRWTAnnotator &anno,
const std::string &outfbase);

void merge_row_compressed(const std::vector<std::string> &filenames,
const std::string &outfile);

Expand Down
43 changes: 43 additions & 0 deletions metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <queue>
#include <numeric>

#include <progress_bar.hpp>

#include "common/algorithms.hpp"
#include "common/serialization.hpp"
#include "common/utils/template_utils.hpp"
Expand All @@ -12,6 +14,9 @@ namespace mtg {
namespace annot {
namespace matrix {

// Number of consecutive rows queried per work item when the matrix rows are
// traversed in blocks (see BRWT::call_rows).
constexpr size_t kNumRowsInBlock = 250'000;


bool BRWT::get(Row row, Column column) const {
assert(row < num_rows());
assert(column < num_columns());
Expand Down Expand Up @@ -54,6 +59,44 @@ BRWT::get_rows(const std::vector<Row> &row_ids) const {
return rows;
}

// Queries the rows described by `begin` and `end` and stores the result in
// `slice`. The row ids are generated with utils::arange<Row>(begin, end - begin)
// (presumably begin, begin + 1, ..., end - 1) and handed to the overload that
// takes an explicit list of rows.
void BRWT::slice_rows(Row begin, Row end, Vector<Column> *slice) const {
    // TODO: querying the index columns in ranges rather than with element-wise
    //       access queries may be faster.
    const auto row_ids = utils::arange<Row>(begin, end - begin);
    slice_rows(row_ids, slice);
}

void BRWT::call_rows(const std::function<void(const Vector<Column> &)> &callback,
bool show_progress) const {
Vector<Column> slice;
ProgressBar progress_bar(num_rows(), "Queried BRWT rows", std::cerr, !show_progress);

#pragma omp parallel for ordered num_threads(get_num_threads()) schedule(dynamic) private(slice)
for (uint64_t i = 0; i < num_rows(); i += kNumRowsInBlock) {
uint64_t begin = i;
uint64_t end = std::min(i + kNumRowsInBlock, num_rows());
assert(begin <= end);

slice.resize(0);
slice_rows(begin, end, &slice);

#pragma omp ordered
{
Vector<Column> row;
for (auto row_begin = slice.begin(); row_begin < slice.end(); ) {
// every row in `slice` ends with `-1`
auto row_end = std::find(row_begin, slice.end(),
std::numeric_limits<Column>::max());
row.assign(row_begin, row_end);
std::sort(row.begin(), row.end());
callback(row);
++progress_bar;
row_begin = row_end + 1;
}
}
}
}


std::vector<Vector<std::pair<BRWT::Column, uint64_t>>>
BRWT::get_column_ranks(const std::vector<Row> &row_ids) const {
std::vector<Vector<std::pair<Column, uint64_t>>> rows(row_ids.size());
Expand Down
4 changes: 4 additions & 0 deletions metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class BRWT : public BinaryMatrix, public GetEntrySupport {
// query row and get ranks of each set bit in its column
std::vector<Vector<std::pair<Column, uint64_t>>>
get_column_ranks(const std::vector<Row> &rows) const;
// Queries all rows of the matrix in blocks and invokes `callback` for each row
// with the row's column indexes sorted in increasing order. Blocks are queried
// by multiple threads, but the callbacks are issued one block at a time.
// `show_progress` toggles the progress bar and defaults to common::get_verbose().
void call_rows(const std::function<void(const Vector<Column> &)> &callback,
bool show_progress = common::get_verbose()) const;

bool load(std::istream &in) override;
void serialize(std::ostream &out) const override;
Expand All @@ -57,6 +59,8 @@ class BRWT : public BinaryMatrix, public GetEntrySupport {
template <typename T>
void slice_rows(const std::vector<Row> &rows, Vector<T> *slice) const;

// Overload of slice_rows() that takes a row range (`begin`, `end`) instead of
// an explicit list of rows; it forwards to the list-based overload.
// NOTE(review): `end` appears to be exclusive -- confirm against utils::arange.
void slice_rows(Row begin, Row end, Vector<Column> *slice) const;

// assigns columns to the child nodes
RangePartition assignments_;
std::unique_ptr<bit_vector> nonzero_rows_;
Expand Down
11 changes: 11 additions & 0 deletions metagraph/src/cli/transform_annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,17 @@ int transform_annotation(Config *config) {
logger->trace("Serialized to {}", config->outfbase);
}
}
} else if (input_anno_type == Config::RowDiffBRWT && config->anno_type == Config::RowDiffRowFlat) {
if (files.size() != 1) {
logger->error("Can only convert row_diff_brwt annotations one at a time");
exit(1);
}
RowDiffBRWTAnnotator annotator;
annotator.load(files[0]);

convert_to_row_diff<RowDiffRowFlatAnnotator>(annotator, config->outfbase);
logger->trace("Serialized to {}", config->outfbase);

} else {
logger->error(
"Conversion to other representations is not implemented for {} "
Expand Down
Loading