Skip to content

Commit cbb7655

Browse files
committed
[improvement][chat]Optimize NL2SQL parsing logic.
1 parent 996cb3d commit cbb7655

File tree

7 files changed

+94
-168
lines changed

7 files changed

+94
-168
lines changed

chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.tencent.supersonic.chat.server.parser;
22

3+
import com.google.common.collect.Lists;
34
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
45
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
56
import com.tencent.supersonic.chat.server.pojo.ChatContext;
@@ -15,10 +16,12 @@
1516
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
1617
import com.tencent.supersonic.common.util.ChatAppManager;
1718
import com.tencent.supersonic.common.util.ContextUtils;
19+
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
1820
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
1921
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
2022
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
2123
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
24+
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
2225
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
2326
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
2427
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -35,6 +38,7 @@
3538
import lombok.extern.slf4j.Slf4j;
3639
import org.slf4j.Logger;
3740
import org.slf4j.LoggerFactory;
41+
import org.springframework.util.CollectionUtils;
3842

3943
import java.util.*;
4044
import java.util.stream.Collectors;
@@ -78,27 +82,24 @@ public void parse(ParseContext parseContext) {
7882
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
7983
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
8084

81-
// inject semantic parse saved by in the chat context
82-
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
83-
ChatContext chatCtx =
84-
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
85-
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
86-
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
87-
}
88-
89-
// for every requested dataSet, recursively invoke rule-based parser
90-
// with different mapModes, unless any valid semantic parse is derived.
85+
// for every requested dataSet, recursively invoke rule-based parser with different
86+
// mapModes
9187
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
9288
for (Long datasetId : requestedDatasets) {
9389
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
94-
ChatParseResp parseResp = parseContext.getResponse();
95-
for (MapModeEnum mode : MapModeEnum.values()) {
90+
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
91+
for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.STRICT, MapModeEnum.MODERATE)) {
9692
queryNLReq.setMapModeEnum(mode);
9793
doParse(queryNLReq, parseResp);
98-
if (!parseResp.getSelectedParses().isEmpty()) {
99-
break;
100-
}
10194
}
95+
if (parseResp.getSelectedParses().isEmpty()) {
96+
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
97+
doParse(queryNLReq, parseResp);
98+
}
99+
List<SemanticParseInfo> sortedParses = parseResp.getSelectedParses().stream()
100+
.sorted(new SemanticParseInfo.SemanticParseComparator()).limit(1)
101+
.collect(Collectors.toList());
102+
parseContext.getResponse().getSelectedParses().addAll(sortedParses);
102103
}
103104
}
104105

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
package com.tencent.supersonic.chat.server.processor.parse;
22

33
import com.tencent.supersonic.chat.server.pojo.ParseContext;
4-
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
5-
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
64
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
7-
import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult;
85
import lombok.extern.slf4j.Slf4j;
96

107
import java.util.*;
@@ -18,50 +15,12 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
1815
@Override
1916
public void process(ParseContext parseContext) {
2017
List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses();
21-
22-
selectedParses.sort((o1, o2) -> {
23-
DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
24-
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
25-
26-
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
27-
if (difference == 0) {
28-
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
29-
if (difference == 0) {
30-
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
31-
}
32-
if (difference == 0) {
33-
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
34-
}
35-
}
36-
return difference >= 0 ? -1 : 1;
37-
});
18+
selectedParses.sort(new SemanticParseInfo.SemanticParseComparator());
3819
// re-assign parseId
3920
for (int i = 0; i < selectedParses.size(); i++) {
4021
SemanticParseInfo parseInfo = selectedParses.get(i);
4122
parseInfo.setId(i + 1);
4223
}
4324
}
4425

45-
private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
46-
double maxMetricSimilarity = 0;
47-
double maxDatasetSimilarity = 0;
48-
double totalSimilarity = 0;
49-
long maxMetricUseCnt = 0L;
50-
for (SchemaElementMatch match : elementMatches) {
51-
if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
52-
maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
53-
}
54-
if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
55-
maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
56-
if (Objects.nonNull(match.getElement().getUseCnt())) {
57-
maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
58-
}
59-
}
60-
totalSimilarity += match.getSimilarity();
61-
}
62-
return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
63-
.maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
64-
.build();
65-
}
66-
6726
}

headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import com.tencent.supersonic.common.pojo.enums.FilterType;
1010
import com.tencent.supersonic.common.pojo.enums.QueryType;
1111
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
12+
import lombok.Builder;
1213
import lombok.Data;
1314

1415
import java.util.Comparator;
@@ -46,8 +47,58 @@ public class SemanticParseInfo {
4647
private String textInfo;
4748
private Map<String, Object> properties = Maps.newHashMap();
4849

49-
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
50+
/**
 * Aggregated similarity statistics for one list of schema element matches.
 * Instances are created through the Lombok-generated builder; fields that are not
 * set on the builder keep their primitive default of zero.
 */
@Data
51+
@Builder
52+
public static class DataSetMatchResult {
53+
// Highest similarity seen among matches whose element type is METRIC.
private double maxMetricSimilarity;
54+
// Highest similarity seen among matches whose element type is DATASET.
// NOTE: "Dateset" is a spelling slip, but the generated accessor and builder
// names are part of the public interface, so the name must stay as it is.
private double maxDatesetSimilarity;
55+
// Sum of the similarity of every match, regardless of element type.
private double totalSimilarity;
56+
// Intended to hold the largest use count of any matched METRIC element.
private long maxMetricUseCnt;
57+
}
58+
59+
public static class SemanticParseComparator implements Comparator<SemanticParseInfo> {
60+
@Override
61+
public int compare(SemanticParseInfo o1, SemanticParseInfo o2) {
62+
DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
63+
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
64+
65+
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
66+
if (difference == 0) {
67+
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
68+
if (difference == 0) {
69+
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
70+
}
71+
if (difference == 0) {
72+
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
73+
}
74+
}
75+
return difference >= 0 ? -1 : 1;
76+
}
77+
78+
private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
79+
double maxMetricSimilarity = 0;
80+
double maxDatasetSimilarity = 0;
81+
double totalSimilarity = 0;
82+
long maxMetricUseCnt = 0L;
83+
for (SchemaElementMatch match : elementMatches) {
84+
if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
85+
maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
86+
}
87+
if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
88+
maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
89+
if (Objects.nonNull(match.getElement().getUseCnt())) {
90+
maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
91+
}
92+
}
93+
totalSimilarity += match.getSimilarity();
94+
}
95+
return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
96+
.maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
97+
.build();
98+
}
99+
}
50100

101+
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
51102
@Override
52103
public int compare(SchemaElement o1, SchemaElement o2) {
53104
if (o1.getOrder() != o2.getOrder()) {
@@ -93,4 +144,19 @@ public long getMetricLimit() {
93144
}
94145
return limit;
95146
}
147+
148+
@Override
149+
public boolean equals(Object o) {
150+
if (this == o)
151+
return true;
152+
if (o == null || getClass() != o.getClass())
153+
return false;
154+
SemanticParseInfo that = (SemanticParseInfo) o;
155+
return Objects.equals(textInfo, that.textInfo);
156+
}
157+
158+
@Override
159+
public int hashCode() {
160+
return Objects.hashCode(textInfo);
161+
}
96162
}

headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetMatchResult.java

Lines changed: 0 additions & 13 deletions
This file was deleted.

headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
44
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
55
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
6+
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
67
import com.tencent.supersonic.headless.chat.ChatQueryContext;
78
import lombok.extern.slf4j.Slf4j;
89
import org.apache.commons.collections.CollectionUtils;
@@ -36,8 +37,9 @@ public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds
3637
}
3738

3839
protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
39-
Map<Long, DataSetMatchResult> dataSetMatchRet = getDataSetMatchResult(schemaMap);
40-
Entry<Long, DataSetMatchResult> selectedDataset =
40+
Map<Long, SemanticParseInfo.DataSetMatchResult> dataSetMatchRet =
41+
getDataSetMatchResult(schemaMap);
42+
Entry<Long, SemanticParseInfo.DataSetMatchResult> selectedDataset =
4143
dataSetMatchRet.entrySet().stream().sorted((o1, o2) -> {
4244
double difference = o1.getValue().getMaxDatesetSimilarity()
4345
- o2.getValue().getMaxDatesetSimilarity();
@@ -63,8 +65,9 @@ protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
6365
return null;
6466
}
6567

66-
protected Map<Long, DataSetMatchResult> getDataSetMatchResult(SchemaMapInfo schemaMap) {
67-
Map<Long, DataSetMatchResult> dateSetMatchRet = new HashMap<>();
68+
protected Map<Long, SemanticParseInfo.DataSetMatchResult> getDataSetMatchResult(
69+
SchemaMapInfo schemaMap) {
70+
Map<Long, SemanticParseInfo.DataSetMatchResult> dateSetMatchRet = new HashMap<>();
6871
for (Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches()
6972
.entrySet()) {
7073
double maxMetricSimilarity = 0;
@@ -84,7 +87,8 @@ protected Map<Long, DataSetMatchResult> getDataSetMatchResult(SchemaMapInfo sche
8487
totalSimilarity += match.getSimilarity();
8588
}
8689
dateSetMatchRet.put(entry.getKey(),
87-
DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
90+
SemanticParseInfo.DataSetMatchResult.builder()
91+
.maxMetricSimilarity(maxMetricSimilarity)
8892
.maxDatesetSimilarity(maxDatasetSimilarity)
8993
.totalSimilarity(totalSimilarity).build());
9094
}

launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ boolean checkNeedToRun() {
129129
public void addSampleChats(Integer agentId) {
130130
Long chatId = chatManageService.addChat(defaultUser, "样例对话1", agentId);
131131
submitText(chatId.intValue(), agentId, "超音数 访问次数");
132-
submitText(chatId.intValue(), agentId, "按部门统计");
133-
submitText(chatId.intValue(), agentId, "查询近30天");
132+
submitText(chatId.intValue(), agentId, "按部门统计近7天访问次数");
134133
submitText(chatId.intValue(), agentId, "alice 停留时长");
135134
submitText(chatId.intValue(), agentId, "访问次数最高的部门");
136135
}

launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java

Lines changed: 0 additions & 90 deletions
This file was deleted.

0 commit comments

Comments
 (0)