process query batches in parallel #521

Merged: 3 commits, May 17, 2025
1 change: 1 addition & 0 deletions metagraph/src/cli/config/config.cpp
@@ -1326,6 +1326,7 @@ if (advanced) {
         // fprintf(stderr, "\t --cache-size [INT] \tnumber of uncompressed rows to store in the cache [0]\n");
         fprintf(stderr, "\t --batch-size [INT] \tquery batch size in bp (0 to disable batch query) [100'000'000]\n");
         if (advanced) {
+            fprintf(stderr, "\t --threads-each [INT]\tnumber of parallel batches [1]\n");
             fprintf(stderr, "\t --RA-ivbuff-size [INT] \tsize (in bytes) of int_vector_buffer used in random access mode (e.g. by row disk annotator) [16384]\n");
         }
         fprintf(stderr, "\n");
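For reference, a minimal standalone sketch of the thread-budget arithmetic behind the new option: the global worker count is split evenly across the batches processed in parallel, as `get_num_threads() / config_.parallel_each` does in query.cpp below. The concrete numbers and the `std::max` guard against a zero quotient are illustrative assumptions, not part of this PR.

// Illustrative only: how --threads-each splits the global thread budget.
// total_threads and parallel_each are made-up example values; the max()
// guard (avoiding 0 threads per batch) is an assumption, not PR code.
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
    std::size_t total_threads = 16;  // global worker count (cf. get_num_threads())
    std::size_t parallel_each = 4;   // batches processed in parallel (cf. --threads-each)

    std::size_t threads_per_batch =
        std::max<std::size_t>(1, total_threads / parallel_each);

    std::cout << parallel_each << " parallel batches, "
              << threads_per_batch << " threads each\n";  // "4 parallel batches, 4 threads each"
    return 0;
}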
89 changes: 46 additions & 43 deletions metagraph/src/cli/query.cpp
@@ -1270,6 +1270,8 @@ QueryExecutor::batched_query_fasta(seq_io::FastaParser &fasta_parser,
     size_t seq_count = 0;
     size_t num_bp = 0;

+    ThreadPool thread_pool(config_.parallel_each);
+    size_t threads_per_batch = get_num_threads() / config_.parallel_each;
     while (it != end) {
         Timer batch_timer;

@@ -1278,62 +1280,63 @@ QueryExecutor::batched_query_fasta(seq_io::FastaParser &fasta_parser,
         // A generator that can be called multiple times until all sequences
         // are called
         std::vector<QuerySequence> seq_batch;
-        std::vector<Alignment> alignments_batch;
         num_bytes_read = 0;

         for ( ; it != end && num_bytes_read <= batch_size; ++it) {
             seq_batch.push_back(QuerySequence { seq_count++, it->name.s, it->seq.s });
             num_bytes_read += it->seq.l;
         }

-        // Align sequences ahead of time on full graph if we don't have batch_align
-        if (aligner_config_ && !config_.batch_align) {
-            alignments_batch.resize(seq_batch.size());
-            logger->trace("Aligning sequences from batch against the full graph...");
-            batch_timer.reset();
-
-            #pragma omp parallel for num_threads(get_num_threads()) schedule(dynamic)
-            for (size_t i = 0; i < seq_batch.size(); ++i) {
-                // Set alignment for this seq_batch
-                alignments_batch[i] = align_sequence(&seq_batch[i].sequence,
-                                                     anno_graph_, *aligner_config_);
-            }
-            logger->trace("Sequences alignment took {} sec", batch_timer.elapsed());
-            batch_timer.reset();
-        }
+        thread_pool.enqueue([&](std::vector<QuerySequence> seq_batch, uint64_t num_bytes_read) {
+            std::vector<Alignment> alignments_batch;
+            // Align sequences ahead of time on full graph if we don't have batch_align
+            if (aligner_config_ && !config_.batch_align) {
+                alignments_batch.resize(seq_batch.size());
+                logger->trace("Aligning sequences from batch against the full graph...");
+                batch_timer.reset();
+
+                #pragma omp parallel for num_threads(threads_per_batch) schedule(dynamic)
+                for (size_t i = 0; i < seq_batch.size(); ++i) {
+                    // Set alignment for this seq_batch
+                    alignments_batch[i] = align_sequence(&seq_batch[i].sequence,
+                                                         anno_graph_, *aligner_config_);
+                }
+                logger->trace("Sequences alignment took {} sec", batch_timer.elapsed());
+                batch_timer.reset();
+            }

-        // Construct the query graph for this batch
-        auto query_graph = construct_query_graph(
-            anno_graph_,
-            [&](auto callback) {
-                for (const auto &seq : seq_batch) {
-                    callback(seq.sequence);
-                }
-            },
-            get_num_threads(),
-            aligner_config_ && config_.batch_align ? &config_ : NULL
-        );
+            // Construct the query graph for this batch
+            auto query_graph = construct_query_graph(
+                anno_graph_,
+                [&](auto callback) {
+                    for (const auto &seq : seq_batch) {
+                        callback(seq.sequence);
+                    }
+                },
+                threads_per_batch,
+                aligner_config_ && config_.batch_align ? &config_ : NULL
+            );

-        auto query_graph_construction = batch_timer.elapsed();
-        batch_timer.reset();
+            auto query_graph_construction = batch_timer.elapsed();
+            batch_timer.reset();

-        #pragma omp parallel for num_threads(get_num_threads()) schedule(dynamic)
-        for (size_t i = 0; i < seq_batch.size(); ++i) {
-            SeqSearchResult search_result
-                = query_sequence(std::move(seq_batch[i]), *query_graph, config_,
-                                 config_.batch_align ? aligner_config_.get() : NULL);
-
-            if (alignments_batch.size())
-                search_result.get_alignment() = std::move(alignments_batch[i]);
-
-            callback(search_result);
-        }
+            #pragma omp parallel for num_threads(threads_per_batch) schedule(dynamic)
+            for (size_t i = 0; i < seq_batch.size(); ++i) {
+                SeqSearchResult search_result
+                    = query_sequence(std::move(seq_batch[i]), *query_graph, config_,
+                                     config_.batch_align ? aligner_config_.get() : NULL);
+
+                if (alignments_batch.size())
+                    search_result.get_alignment() = std::move(alignments_batch[i]);
+
+                callback(search_result);
+            }

-        logger->trace("Query graph constructed for batch of sequences"
-                      " with {} bases from '{}' in {:.5f} sec, query redundancy: {:.2f} bp/kmer, queried in {:.5f} sec",
-                      num_bytes_read, fasta_parser.get_filename(), query_graph_construction,
-                      (double)num_bytes_read / query_graph->get_graph().num_nodes(),
-                      batch_timer.elapsed());
+            logger->trace("Query graph constructed for batch of sequences"
+                          " with {} bases from '{}' in {:.5f} sec, query redundancy: {:.2f} bp/kmer, queried in {:.5f} sec",
+                          num_bytes_read, fasta_parser.get_filename(), query_graph_construction,
+                          (double)num_bytes_read / query_graph->get_graph().num_nodes(),
+                          batch_timer.elapsed());
+        }, std::move(seq_batch), num_bytes_read);

         num_bp += num_bytes_read;
     }
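To make the control flow of the change easier to follow outside of MetaGraph, here is a self-contained sketch of the same nested-parallelism pattern: a small number of batches run concurrently, and each batch runs its inner loop with a reduced OpenMP thread count, as with num_threads(threads_per_batch) above. std::async stands in for the project's ThreadPool, and the batch contents and work() body are placeholders, not MetaGraph code.

// Standalone illustration of the pattern in batched_query_fasta() above:
// several batches are processed concurrently, and each batch parallelizes
// its own loop with a reduced OpenMP thread count. std::async emulates the
// ThreadPool used in the PR; work() and the batch data are placeholders.
// Compile with -fopenmp; without it the pragma is simply ignored.
#include <cstddef>
#include <cstdio>
#include <future>
#include <vector>

static long work(int value) {
    return static_cast<long>(value) * value;  // stand-in for per-sequence query work
}

int main() {
    const int total_threads = 8;      // global budget (cf. get_num_threads())
    const int parallel_batches = 2;   // batches in flight (cf. --threads-each)
    const int threads_per_batch = total_threads / parallel_batches;

    std::vector<std::vector<int>> batches = { {1, 2, 3, 4}, {5, 6, 7, 8} };
    std::vector<std::future<long>> pending;

    for (const auto &batch : batches) {
        // One task per batch, "enqueued" for asynchronous execution.
        pending.push_back(std::async(std::launch::async, [&batch, threads_per_batch] {
            long sum = 0;
            // The inner per-batch loop gets only its share of the thread budget,
            // mirroring num_threads(threads_per_batch) in the PR.
            #pragma omp parallel for num_threads(threads_per_batch) schedule(dynamic) reduction(+:sum)
            for (std::size_t i = 0; i < batch.size(); ++i)
                sum += work(batch[i]);
            return sum;
        }));
    }

    for (std::size_t b = 0; b < pending.size(); ++b)
        std::printf("batch %zu result: %ld\n", b, pending[b].get());
    return 0;
}

Built with -fopenmp, the two batches run concurrently with four OpenMP threads each, so the outer concurrency multiplied by the inner thread count stays within the global budget, which is the design intent of splitting get_num_threads() by the number of parallel batches.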