|
| 1 | +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | +#include <iostream> |
| 15 | +#include <sstream> |
| 16 | + |
| 17 | +#include "fastdeploy/function/reduce.h" |
| 18 | +#include "fastdeploy/function/softmax.h" |
| 19 | +#include "fastdeploy/text.h" |
| 20 | +#include "faster_tokenizer/tokenizers/ernie_faster_tokenizer.h" |
| 21 | +#include "uie.h" |
| 22 | + |
| 23 | +using namespace paddlenlp; |
| 24 | + |
| 25 | +#ifdef WIN32 |
| 26 | +const char sep = '\\'; |
| 27 | +#else |
| 28 | +const char sep = '/'; |
| 29 | +#endif |
| 30 | + |
| 31 | +int main(int argc, char* argv[]) { |
| 32 | + if (argc < 3) { |
| 33 | + std::cout << "Usage: infer_demo path/to/model run_option, " |
| 34 | + "e.g ./infer_demo uie-base 0" |
| 35 | + << std::endl; |
| 36 | + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " |
| 37 | + "with gpu." |
| 38 | + << std::endl; |
| 39 | + return -1; |
| 40 | + } |
| 41 | + auto option = fastdeploy::RuntimeOption(); |
| 42 | + if (std::atoi(argv[2]) == 0) { |
| 43 | + option.UseCpu(); |
| 44 | + } else { |
| 45 | + option.UseGpu(); |
| 46 | + } |
| 47 | + std::string model_dir(argv[1]); |
| 48 | + std::string model_path = model_dir + sep + "inference.pdmodel"; |
| 49 | + std::string param_path = model_dir + sep + "inference.pdiparams"; |
| 50 | + std::string vocab_path = model_dir + sep + "vocab.txt"; |
| 51 | + |
| 52 | + auto predictor = UIEModel(model_path, param_path, vocab_path, 0.5, 128, |
| 53 | + {"时间", "选手", "赛事名称"}, option); |
| 54 | + fastdeploy::FDINFO << "After init predictor" << std::endl; |
| 55 | + std::vector<std::unordered_map<std::string, std::vector<UIEResult>>> results; |
| 56 | + // Named Entity Recognition |
| 57 | + predictor.Predict({"2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷" |
| 58 | + "爱凌以188.25分获得金牌!"}, |
| 59 | + &results); |
| 60 | + std::cout << results << std::endl; |
| 61 | + results.clear(); |
| 62 | + |
| 63 | + // Relation Extraction |
| 64 | + predictor.SetSchema({{"竞赛名称", |
| 65 | + {SchemaNode("主办方"), SchemaNode("承办方"), |
| 66 | + SchemaNode("已举办次数")}}}); |
| 67 | + predictor.Predict( |
| 68 | + {"2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度" |
| 69 | + "公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会" |
| 70 | + "承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。"}, |
| 71 | + &results); |
| 72 | + std::cout << results << std::endl; |
| 73 | + results.clear(); |
| 74 | + |
| 75 | + // Event Extraction |
| 76 | + predictor.SetSchema({{"地震触发词", |
| 77 | + {SchemaNode("地震强度"), SchemaNode("时间"), |
| 78 | + SchemaNode("震中位置"), SchemaNode("震源深度")}}}); |
| 79 | + predictor.Predict( |
| 80 | + {"中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24." |
| 81 | + "34度,东经99.98度)发生3.5级地震,震源深度10千米。"}, |
| 82 | + &results); |
| 83 | + std::cout << results << std::endl; |
| 84 | + results.clear(); |
| 85 | + |
| 86 | + // Opinion Extraction |
| 87 | + predictor.SetSchema( |
| 88 | + {{"评价维度", |
| 89 | + {SchemaNode("观点词"), SchemaNode("情感倾向[正向,负向]")}}}); |
| 90 | + predictor.Predict( |
| 91 | + {"店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队"}, |
| 92 | + &results); |
| 93 | + std::cout << results << std::endl; |
| 94 | + results.clear(); |
| 95 | + |
| 96 | + // Sequence classification |
| 97 | + predictor.SetSchema({"情感倾向[正向,负向]"}); |
| 98 | + predictor.Predict({"这个产品用起来真的很流畅,我非常喜欢"}, &results); |
| 99 | + std::cout << results << std::endl; |
| 100 | + results.clear(); |
| 101 | + |
| 102 | + // Cross task extraction |
| 103 | + |
| 104 | + predictor.SetSchema({{"法院", {}}, |
| 105 | + {"原告", {SchemaNode("委托代理人")}}, |
| 106 | + {"被告", {SchemaNode("委托代理人")}}}); |
| 107 | + predictor.Predict({"北京市海淀区人民法院\n民事判决书\n(199x)" |
| 108 | + "建初字第xxx号\n原告:张三。\n委托代理人李四,北京市 " |
| 109 | + "A律师事务所律师。\n被告:B公司,法定代表人王五,开发公司" |
| 110 | + "总经理。\n委托代理人赵六,北京市 C律师事务所律师。"}, |
| 111 | + &results); |
| 112 | + std::cout << results << std::endl; |
| 113 | + results.clear(); |
| 114 | + return 0; |
| 115 | +} |
0 commit comments