Skip to content

Commit cb361de

Browse files
👥 Add VariableAggregatorNodeDataConverter and generics
VariableAggregatorNodeDataConverter and generics
1 parent b424767 commit cb361de

File tree

12 files changed

+176
-40
lines changed

12 files changed

+176
-40
lines changed

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717

1818
<properties>
19-
<maven.compiler.source>8</maven.compiler.source>
20-
<maven.compiler.target>8</maven.compiler.target>
19+
<maven.compiler.source>17</maven.compiler.source>
20+
<maven.compiler.target>17</maven.compiler.target>
2121
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
2222
</properties>
2323

spring-ai-alibaba-graph/spring-ai-alibaba-graph-samples/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717

1818
<properties>
19-
<maven.compiler.source>8</maven.compiler.source>
20-
<maven.compiler.target>8</maven.compiler.target>
19+
<maven.compiler.source>17</maven.compiler.source>
20+
<maven.compiler.target>17</maven.compiler.target>
2121
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
2222
</properties>
2323

spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/model/workflow/NodeType.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ public enum NodeType {
1414

1515
RETRIEVER("RETRIEVER", "knowledge-retrieval"),
1616

17+
AGGREGATOR("AGGREGATOR", "variable-aggregator"),
18+
1719
HUMAN("HUMAN", "unsupported"),;
1820

1921
private String value;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package com.alibaba.cloud.ai.model.workflow.nodedata;
2+
3+
import com.alibaba.cloud.ai.model.Variable;
4+
import com.alibaba.cloud.ai.model.VariableSelector;
5+
import com.alibaba.cloud.ai.model.workflow.NodeData;
6+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
7+
import com.fasterxml.jackson.annotation.JsonInclude;
8+
import com.fasterxml.jackson.annotation.JsonProperty;
9+
import lombok.Data;
10+
import lombok.EqualsAndHashCode;
11+
import lombok.experimental.Accessors;
12+
13+
import java.util.List;
14+
15+
/***
16+
*
17+
*/
18+
@EqualsAndHashCode(callSuper = true)
19+
@Data
20+
@Accessors(chain = true)
21+
@JsonInclude(JsonInclude.Include.NON_NULL)
22+
@JsonIgnoreProperties(ignoreUnknown = true)
23+
public class VariableAggregatorNodeData extends NodeData {
24+
25+
private String type;
26+
private String title;
27+
private String desc;
28+
private List<List<String>> variables;
29+
@JsonProperty("output_type")
30+
private String output_type;
31+
private boolean selected;
32+
@JsonProperty("advanced_settings")
33+
private AdvancedSettings advanced_settings;
34+
35+
public VariableAggregatorNodeData(List<VariableSelector> inputs, List<Variable> outputs) {
36+
super(inputs, outputs);
37+
}
38+
39+
40+
41+
@Data
42+
@JsonInclude(JsonInclude.Include.NON_NULL)
43+
@JsonIgnoreProperties(ignoreUnknown = true)
44+
public static class Groups{
45+
@JsonProperty("output_type")
46+
private String output_type;
47+
private List<List<String>> variables;
48+
@JsonProperty("group_name")
49+
private String group_name;
50+
private String groupId;
51+
}
52+
53+
@Data
54+
@JsonInclude(JsonInclude.Include.NON_NULL)
55+
@JsonIgnoreProperties(ignoreUnknown = true)
56+
public static class AdvancedSettings {
57+
@JsonProperty("group_enabled")
58+
private boolean group_enabled;
59+
private List<Groups> groups;
60+
}
61+
62+
63+
}

spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/service/dsl/NodeDataConverter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
* NodeDataConverter defined the mutual conversion between specific DSL data and
99
* {@link NodeData}
1010
*/
11-
public interface NodeDataConverter {
11+
public interface NodeDataConverter<T extends NodeData> {
1212

1313
/**
1414
* Judge if this converter support this node type
@@ -22,13 +22,13 @@ public interface NodeDataConverter {
2222
* @param data DSL data
2323
* @return converted {@link NodeData}
2424
*/
25-
NodeData parseDifyData(Map<String, Object> data);
25+
T parseDifyData(Map<String, Object> data);
2626

2727
/**
2828
* Dump NodeData to DSL data
2929
* @param nodeData {@link NodeData}
3030
* @return converted DSL data
3131
*/
32-
Map<String, Object> dumpDifyData(NodeData nodeData);
32+
Map<String, Object> dumpDifyData(T nodeData);
3333

3434
}

spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/service/dsl/nodes/AnswerNodeDataConverter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
import java.util.Map;
1515

1616
@Component
17-
public class AnswerNodeDataConverter implements NodeDataConverter {
17+
public class AnswerNodeDataConverter implements NodeDataConverter<AnswerNodeData> {
1818

1919
@Override
2020
public Boolean supportType(String nodeType) {
2121
return NodeType.ANSWER.value().equals(nodeType);
2222
}
2323

2424
@Override
25-
public NodeData parseDifyData(Map<String, Object> data) {
25+
public AnswerNodeData parseDifyData(Map<String, Object> data) {
2626
String difyTmpl = (String) data.get("answer");
2727
List<String> variables = new ArrayList<>();
2828
String tmpl = StringTemplateUtil.fromDifyTmpl(difyTmpl, variables);
@@ -34,7 +34,7 @@ public NodeData parseDifyData(Map<String, Object> data) {
3434
}
3535

3636
@Override
37-
public Map<String, Object> dumpDifyData(NodeData nodeData) {
37+
public Map<String, Object> dumpDifyData(AnswerNodeData nodeData) {
3838
AnswerNodeData answerNodeData = (AnswerNodeData) nodeData;
3939
Map<String, Object> data = new HashMap<>();
4040
String difyTmpl = StringTemplateUtil.toDifyTmpl(answerNodeData.getAnswer());

spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/service/dsl/nodes/CodeNodeDataConverter.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
import java.util.*;
1313

1414
@Component
15-
public class CodeNodeDataConverter implements NodeDataConverter {
15+
public class CodeNodeDataConverter implements NodeDataConverter<CodeNodeData> {
1616

1717
@Override
1818
public Boolean supportType(String nodeType) {
1919
return NodeType.CODE.value().equals(nodeType);
2020
}
2121

2222
@Override
23-
public NodeData parseDifyData(Map<String, Object> data) {
23+
public CodeNodeData parseDifyData(Map<String, Object> data) {
2424
List<Map<String, Object>> variables = (List<Map<String, Object>>) data.get("variables");
2525
List<VariableSelector> inputs = variables.stream().map(variable -> {
2626
List<String> selector = (List<String>) variable.get("value_selector");
@@ -40,11 +40,10 @@ public NodeData parseDifyData(Map<String, Object> data) {
4040
}
4141

4242
@Override
43-
public Map<String, Object> dumpDifyData(NodeData nodeData) {
44-
CodeNodeData codeNodeData = (CodeNodeData) nodeData;
43+
public Map<String, Object> dumpDifyData(CodeNodeData nodeData) {
4544
Map<String, Object> data = new HashMap<>();
46-
data.put("code", codeNodeData.getCode());
47-
data.put("code_language", codeNodeData.getCodeLanguage());
45+
data.put("code", nodeData.getCode());
46+
data.put("code_language", nodeData.getCodeLanguage());
4847
List<Map<String, Object>> inputVars = new ArrayList<>();
4948
nodeData.getInputs().forEach(v -> {
5049
inputVars.add(Map.of("variable", v.getLabel(), "value_selector", List.of(v.getNamespace(), v.getName())));

spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/service/dsl/nodes/EndNodeDataConverter.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
import java.util.Map;
1313

1414
@Component
15-
public class EndNodeDataConverter implements NodeDataConverter {
15+
public class EndNodeDataConverter implements NodeDataConverter<EndNodeData> {
1616

1717
@Override
1818
public Boolean supportType(String nodeType) {
1919
return NodeType.END.value().equals(nodeType);
2020
}
2121

2222
@Override
23-
public NodeData parseDifyData(Map<String, Object> data) {
23+
public EndNodeData parseDifyData(Map<String, Object> data) {
2424
List<Map<String, Object>> outputsMap = (List<Map<String, Object>>) data.get("outputs");
2525
List<VariableSelector> inputs = outputsMap.stream().map(output -> {
2626
List<String> valueSelector = (List<String>) output.get("value_selector");
@@ -31,10 +31,9 @@ public NodeData parseDifyData(Map<String, Object> data) {
3131
}
3232

3333
@Override
34-
public Map<String, Object> dumpDifyData(NodeData nodeData) {
35-
EndNodeData endNodeData = (EndNodeData) nodeData;
34+
public Map<String, Object> dumpDifyData(EndNodeData nodeData) {
3635
Map<String, Object> data = new HashMap<>();
37-
List<Map<String, Object>> outputsMap = endNodeData.getInputs()
36+
List<Map<String, Object>> outputsMap = nodeData.getInputs()
3837
.stream()
3938
.map(input -> Map.of("value_selector", List.of(input.getNamespace(), input.getName()), "variable",
4039
input.getLabel()))

spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/service/dsl/nodes/LLMNodeDataConverter.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
import java.util.Map;
1919

2020
@Component
21-
public class LLMNodeDataConverter implements NodeDataConverter {
21+
public class LLMNodeDataConverter implements NodeDataConverter<LLMNodeData> {
2222

2323
@Override
2424
public Boolean supportType(String nodeType) {
2525
return NodeType.LLM.value().equals(nodeType);
2626
}
2727

2828
@Override
29-
public NodeData parseDifyData(Map<String, Object> data) {
29+
public LLMNodeData parseDifyData(Map<String, Object> data) {
3030
List<VariableSelector> inputs = new ArrayList<>();
3131
// convert prompt template
3232
Map<String, Object> context = (Map<String, Object>) data.get("context");
@@ -89,8 +89,7 @@ public NodeData parseDifyData(Map<String, Object> data) {
8989
}
9090

9191
@Override
92-
public Map<String, Object> dumpDifyData(NodeData nodeData) {
93-
LLMNodeData llmNodeData = (LLMNodeData) nodeData;
92+
public Map<String, Object> dumpDifyData(LLMNodeData nodeData) {
9493
Map<String, Object> data = new HashMap<>();
9594
ObjectMapper objectMapper = new ObjectMapper();
9695
objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
@@ -101,19 +100,19 @@ public Map<String, Object> dumpDifyData(NodeData nodeData) {
101100

102101
));
103102
// put memory
104-
LLMNodeData.MemoryConfig memory = llmNodeData.getMemoryConfig();
103+
LLMNodeData.MemoryConfig memory = nodeData.getMemoryConfig();
105104
if (memory != null) {
106105
data.put("memory",
107106
Map.of("query_prompt_template", StringTemplateUtil.toDifyTmpl(memory.getLastMessageTemplate()),
108107
"role_prefix", Map.of("assistant", "", "user", ""), "window",
109108
Map.of("enabled", memory.getWindowEnabled(), "size", memory.getWindowSize())));
110109
}
111110
// put model
112-
LLMNodeData.ModelConfig model = llmNodeData.getModel();
111+
LLMNodeData.ModelConfig model = nodeData.getModel();
113112
data.put("model", Map.of("mode", model.getMode(), "name", model.getName(), "provider", model.getProvider(),
114113
"completion_params", objectMapper.convertValue(model.getCompletionParams(), Map.class)));
115114
// put prompt template
116-
List<LLMNodeData.PromptTemplate> tmplList = llmNodeData.getPromptTemplate();
115+
List<LLMNodeData.PromptTemplate> tmplList = nodeData.getPromptTemplate();
117116
List<Map<String, String>> difyTmplList = tmplList.stream().map(tmpl -> {
118117
String difyTmpl = StringTemplateUtil.toDifyTmpl(tmpl.getText());
119118
return Map.of("role", tmpl.getRole(), "text", difyTmpl);

spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/service/dsl/nodes/RetrieverNodeDataConverter.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
import java.util.Optional;
1414

1515
@Component
16-
public class RetrieverNodeDataConverter implements NodeDataConverter {
16+
public class RetrieverNodeDataConverter implements NodeDataConverter<RetrieverNodeData> {
1717

1818
@Override
1919
public Boolean supportType(String nodeType) {
2020
return NodeType.RETRIEVER.value().equals(nodeType);
2121
}
2222

2323
@Override
24-
public NodeData parseDifyData(Map<String, Object> data) {
24+
public RetrieverNodeData parseDifyData(Map<String, Object> data) {
2525
List<String> selector = (List<String>) data.get("query_variable_selector");
2626
List<VariableSelector> inputs = List.of(new VariableSelector(selector.get(0), selector.get(1)));
2727
Map<String, Object> configMap = (Map<String, Object>) data.get("multiple_retrieval_config");
@@ -43,18 +43,17 @@ public NodeData parseDifyData(Map<String, Object> data) {
4343
}
4444

4545
@Override
46-
public Map<String, Object> dumpDifyData(NodeData nodeData) {
46+
public Map<String, Object> dumpDifyData(RetrieverNodeData nodeData) {
4747
Map<String, Object> data = new HashMap<>();
48-
RetrieverNodeData retrieverNodeData = (RetrieverNodeData) nodeData;
49-
RetrieverNodeData.RerankOptions rerankConfig = retrieverNodeData.getMultipleRetrievalOptions();
48+
RetrieverNodeData.RerankOptions rerankConfig = nodeData.getMultipleRetrievalOptions();
5049
Map<String, Object> configMap = Map.of("reranking_enabled", rerankConfig.getEnableRerank(), "reranking_mode",
5150
"reranking_model", "reranking_model",
5251
Map.of("model", rerankConfig.getRerankModelName(), "provider", rerankConfig.getRerankModelProvider()),
5352
"score_threshold", rerankConfig.getRerankThreshold(), "top_k", rerankConfig.getRerankTopK());
5453
data.put("dataset_ids", List.of());
5554
data.put("multiple_retrieval_config", configMap);
56-
data.put("query_variable_selector", List.of(retrieverNodeData.getInputs().get(0).getNamespace(),
57-
retrieverNodeData.getInputs().get(0).getName()));
55+
data.put("query_variable_selector", List.of(nodeData.getInputs().get(0).getNamespace(),
56+
nodeData.getInputs().get(0).getName()));
5857
data.put("retrieval_mode", "multiple");
5958
return data;
6059
}

0 commit comments

Comments
 (0)