Skip to content

Importance Matrix calculation #4861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
imatrix: WIP
  • Loading branch information
Kawrakow committed Jan 10, 2024
commit 055a0c2e12d6c558638e71b449efe5162ca33dd0
144 changes: 108 additions & 36 deletions examples/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,25 @@ struct StatParams {
bool collect_output_weight = false;
};

// Serialize the collected statistics to a binary file.
//
// Data written, in order: an int with the number of entries, then for every
// entry an int name length, the name bytes (no terminator), the raw bytes of
// the entry's ncall field, an int value count and that many floats.
//
//   fname - path of the file to create or overwrite
//   stats - map from name to accumulated statistics
//   ncall - call count; only used in the log message
//
// On failure to open or write the file an error is printed to stderr and the
// success message is not emitted.
static void ik_save_statistics(const char * fname, const std::unordered_map<std::string, Stats>& stats, int ncall) {
    std::ofstream out(fname, std::ios::binary);
    if (!out) {
        fprintf(stderr, "%s: failed to open %s for writing\n", __func__, fname);
        return;
    }
    int n_entries = stats.size();
    out.write((const char*)&n_entries, sizeof(n_entries));
    for (auto& p : stats) {
        int len = p.first.size();
        out.write((const char*)&len, sizeof(len));
        out.write(p.first.c_str(), len);
        out.write((const char*)&p.second.ncall, sizeof(p.second.ncall));
        int nval = p.second.values.size();
        out.write((const char*)&nval, sizeof(nval));
        if (nval > 0) out.write((const char*)p.second.values.data(), nval*sizeof(float));
    }
    // Flush explicitly so that a write error is visible before success is reported.
    out.flush();
    if (!out) {
        fprintf(stderr, "%s: failed to write data to %s\n", __func__, fname);
        return;
    }
    fprintf(stderr, "%s: stored collected data after %d calls in %s\n",__func__,ncall,fname);
}
// Accumulates per-tensor statistics and writes them to a file.
// collect_imatrix() and save_imatrix() are defined out of line.
class IMatrixCollector {
public:
    IMatrixCollector() = default;

    // Take over the given settings (moved into m_params).
    void set_parameters(StatParams&& params) {
        m_params = std::move(params);
    }

    // Update the statistics for the given pair of tensors.
    void collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1);

    // Write the statistics collected so far to the output file.
    void save_imatrix() const;

private:
    // Statistics keyed by name (collect_imatrix indexes it with src0->name).
    std::unordered_map<std::string, Stats> m_stats;
    // Settings supplied through set_parameters().
    StatParams m_params;
    // Locked by collect_imatrix() while it updates m_stats.
    std::mutex m_mutex;
    // Largest per-entry call count observed so far by collect_imatrix().
    int m_last_call = 0;
};

static void ik_collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
static int last_call = 0;
static std::mutex mutex;
void IMatrixCollector::collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return;
//if (strncmp(src0->name, "blk.", 4) != 0 && strcmp(src0->name, "output.weight") != 0) return;
if (strncmp(src0->name, "blk.", 4) != 0) return;
std::lock_guard<std::mutex> lock(mutex);
auto& g_stats = ik_get_stats();
auto& e = g_stats[src0->name];
if (!(strncmp(src0->name, "blk.", 4) == 0 || (m_params.collect_output_weight && strcmp(src0->name, "output.weight") == 0))) return;
//if (strncmp(src0->name, "blk.", 4) != 0) return;
std::lock_guard<std::mutex> lock(m_mutex);
auto& e = m_stats[src0->name];
if (e.values.empty()) {
e.values.resize(src1->ne[0], 0);
}
Expand All @@ -66,21 +60,89 @@ static void ik_collect_imatrix(const struct ggml_tensor * src0, const struct ggm
exit(1); //GGML_ASSERT(false);
}
++e.ncall;
printf("%s[%d]: %s, %d x %d, %d\n",__func__,last_call,src0->name,(int)src1->ne[0],(int)src1->ne[1],(int)src1->type);
printf("%s[%d]: %s, %d x %d, %d\n",__func__,m_last_call,src0->name,(int)src1->ne[0],(int)src1->ne[1],(int)src1->type);
for (int row = 0; row < (int)src1->ne[1]; ++row) {
const float * x = (const float *)src1->data + row * src1->ne[0];
for (int j = 0; j < (int)src1->ne[0]; ++j) {
e.values[j] += x[j]*x[j];
}
}
if (e.ncall > last_call) {
last_call = e.ncall;
if (last_call % 10 == 0) {
ik_save_statistics("stats.dat", g_stats, last_call);
if (e.ncall > m_last_call) {
m_last_call = e.ncall;
if (m_last_call % m_params.n_output_frequency == 0) {
save_imatrix();
}
}
}

// Write all collected statistics to a binary file.
//
// The file name is m_params.ofile, or "imatrix.dat" when that is empty.
// Data written, in order: an int with the number of entries, then for every
// entry an int name length, the name bytes (no terminator), the raw bytes of
// the entry's ncall field, an int value count and that many floats.
//
// On failure to open or write the file an error is printed to stderr and the
// success message is not emitted.
//
// NOTE(review): this function does not lock m_mutex itself; collect_imatrix()
// calls it while already holding the lock.
void IMatrixCollector::save_imatrix() const {
    const char * fname = m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str();
    std::ofstream out(fname, std::ios::binary);
    if (!out) {
        fprintf(stderr, "%s: failed to open %s for writing\n", __func__, fname);
        return;
    }
    int n_entries = m_stats.size();
    out.write((const char*)&n_entries, sizeof(n_entries));
    for (auto& p : m_stats) {
        int len = p.first.size();
        out.write((const char*)&len, sizeof(len));
        out.write(p.first.c_str(), len);
        out.write((const char*)&p.second.ncall, sizeof(p.second.ncall));
        int nval = p.second.values.size();
        out.write((const char*)&nval, sizeof(nval));
        if (nval > 0) out.write((const char*)p.second.values.data(), nval*sizeof(float));
    }
    // Flush explicitly so that a write error is visible before success is reported.
    out.flush();
    if (!out) {
        fprintf(stderr, "%s: failed to write data to %s\n", __func__, fname);
        return;
    }
    fprintf(stderr, "%s: stored collected data after %d calls in %s\n",__func__,m_last_call,fname);
}

// File-scope collector instance: ik_collect_imatrix() forwards to it, and main()
// configures it with set_parameters() and calls save_imatrix() on it.
static IMatrixCollector g_collector;

//static void ik_save_statistics(const char * fname, const std::unordered_map<std::string, Stats>& stats, int ncall) {
// std::ofstream out(fname, std::ios::binary);
// int n_entries = stats.size();
// out.write((const char*)&n_entries, sizeof(n_entries));
// for (auto& p : stats) {
// int len = p.first.size();
// out.write((const char*)&len, sizeof(len));
// out.write(p.first.c_str(), len);
// out.write((const char*)&p.second.ncall, sizeof(p.second.ncall));
// int nval = p.second.values.size();
// out.write((const char*)&nval, sizeof(nval));
// if (nval > 0) out.write((const char*)p.second.values.data(), nval*sizeof(float));
// }
// fprintf(stderr, "%s: stored collected data after %d calls in %s\n",__func__,ncall,fname);
//}

// Free-function callback with the ggml_collect_imatrix_t signature; main() registers it
// via ggml_set_imatrix_collection(). It only forwards both tensors to g_collector.
static void ik_collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
g_collector.collect_imatrix(src0, src1);
// The commented-out lines below are the previous in-place implementation, which kept its
// state in function-local statics; the live logic is in IMatrixCollector::collect_imatrix.
//static int last_call = 0;
//static std::mutex mutex;
//if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return;
////if (strncmp(src0->name, "blk.", 4) != 0 && strcmp(src0->name, "output.weight") != 0) return;
//if (strncmp(src0->name, "blk.", 4) != 0) return;
//std::lock_guard<std::mutex> lock(mutex);
//auto& g_stats = ik_get_stats();
//auto& e = g_stats[src0->name];
//if (e.values.empty()) {
// e.values.resize(src1->ne[0], 0);
//}
//else if (e.values.size() != (size_t)src1->ne[0]) {
// fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", src0->name, (int)e.values.size(), (int)src1->ne[0]);
// exit(1); //GGML_ASSERT(false);
//}
//++e.ncall;
//printf("%s[%d]: %s, %d x %d, %d\n",__func__,last_call,src0->name,(int)src1->ne[0],(int)src1->ne[1],(int)src1->type);
//for (int row = 0; row < (int)src1->ne[1]; ++row) {
// const float * x = (const float *)src1->data + row * src1->ne[0];
// for (int j = 0; j < (int)src1->ne[0]; ++j) {
// e.values[j] += x[j]*x[j];
// }
//}
//if (e.ncall > last_call) {
// last_call = e.ncall;
// if (last_call % 10 == 0) {
// ik_save_statistics("stats.dat", g_stats, last_call);
// }
//}
}


struct results_log_softmax {
double log_softmax;
Expand Down Expand Up @@ -273,7 +335,9 @@ int main(int argc, char ** argv) {

StatParams sparams;
std::vector<char*> args;
for (int iarg = 1; iarg < argc-1; ++iarg) {
args.push_back(argv[0]);
int iarg = 1;
for (; iarg < argc-1; ++iarg) {
std::string arg{argv[iarg]};
if (arg == "-o" || arg == "--output-file") {
sparams.ofile = argv[++iarg];
Expand All @@ -287,14 +351,20 @@ int main(int argc, char ** argv) {
args.push_back(argv[iarg]);
}
}
if (iarg < argc) {
args.push_back(argv[iarg]);
}

gpt_params params;
params.n_batch = 512;
if (!gpt_params_parse(args.size(), args.data(), params)) {
return 1;
}

ggml_set_stat_collection(ik_collect_imatrix);
g_collector.set_parameters(std::move(sparams));

ggml_set_imatrix_collection(ik_collect_imatrix);
ggml_set_imatrix_collection(ik_collect_imatrix);

params.logits_all = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);
Expand Down Expand Up @@ -340,12 +410,14 @@ int main(int argc, char ** argv) {
if (!OK) {
return 1;
}
auto& stats = ik_get_stats();
int ncall = 0;
for (auto& s : stats) {
ncall = std::max(ncall, s.second.ncall);
}
ik_save_statistics(sparams.ofile.c_str(), stats, ncall);

g_collector.save_imatrix();
//auto& stats = ik_get_stats();
//int ncall = 0;
//for (auto& s : stats) {
// ncall = std::max(ncall, s.second.ncall);
//}
//ik_save_statistics(sparams.ofile.c_str(), stats, ncall);

llama_print_timings(ctx);

Expand Down
2 changes: 1 addition & 1 deletion ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * rest

ggml_collect_imatrix_t g_imatrix_collect = NULL;

void ggml_set_stat_collection(ggml_collect_imatrix_t imatrix_collect) {
void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect) {
g_imatrix_collect = imatrix_collect;
}

Expand Down
2 changes: 1 addition & 1 deletion ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2077,7 +2077,7 @@ extern "C" {
// Importance matrix
//
typedef void(*ggml_collect_imatrix_t)(const struct ggml_tensor * src0, const struct ggml_tensor * src1);
GGML_API void ggml_set_stat_collection(ggml_collect_imatrix_t imatrix_collect);
GGML_API void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect);

//
// gguf
Expand Down