diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 3a739785fe..bd62cbf1f0 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -120,8 +120,12 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + var isBlock bool + jsonData, isBlock, err = relaycommon.ApplyParamOverride(jsonData, info) if err != nil { + if isBlock { + return types.NewError(err, types.ErrorCodeContidionTriggerBlock, types.ErrOptionWithSkipRetry()) + } return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/common/condition.go b/relay/common/condition.go new file mode 100644 index 0000000000..973db77fa6 --- /dev/null +++ b/relay/common/condition.go @@ -0,0 +1,213 @@ +package common + +import ( + "encoding/json" + "fmt" + "github.com/tidwall/gjson" + "regexp" + "strconv" + "strings" +) + +// ConditionChecker 条件检查器接口 +type ConditionChecker interface { + CheckConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) + CheckSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) +} + +// DefaultConditionChecker 默认条件检查器实现 +type DefaultConditionChecker struct{} + +// NewConditionChecker 创建新的条件检查器 +func NewConditionChecker() ConditionChecker { + return &DefaultConditionChecker{} +} + +type ConditionOperation struct { + Path string `json:"path"` // JSON路径 + Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte + Value interface{} `json:"value"` // 匹配的值 + Invert bool `json:"invert"` // 反选功能,true表示取反结果 + PassMissingKey bool `json:"pass_missing_key"` // 未获取到json key时的行为 +} + +// CheckConditions 检查多个条件 +func CheckConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) { + checker := NewConditionChecker() + return checker.CheckConditions(jsonStr, conditions, logic) +} + +// CheckSingleCondition 检查单个条件 +func CheckSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) { + checker := NewConditionChecker() + return checker.CheckSingleCondition(jsonStr, condition) +} + +func (c *DefaultConditionChecker) CheckConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) { + return checkConditions(jsonStr, conditions, logic) +} + +func (c *DefaultConditionChecker) CheckSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) { + return checkSingleCondition(jsonStr, condition) +} + +func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) { + if len(conditions) == 0 { + return true, nil // 没有条件,直接通过 + } + results := make([]bool, len(conditions)) + for i, condition := range conditions { + result, err := checkSingleCondition(jsonStr, condition) + if err != nil { + return false, err + } + results[i] = result + } + + if strings.ToUpper(logic) == "AND" { + for _, result := range results { + if !result { + return false, nil + } + } + return true, nil + } else { + for _, result := range results { + if result { + return true, nil + } + } + return false, nil + } +} + +func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) { + // 处理负数索引 + path := processNegativeIndex(jsonStr, condition.Path) + value := gjson.Get(jsonStr, path) + if !value.Exists() { + if condition.PassMissingKey { + return true, nil + } + return false, nil + } + + // 利用gjson的类型解析 + targetBytes, err := json.Marshal(condition.Value) + if err != nil { + return false, fmt.Errorf("failed to marshal condition value: %v", err) + } + targetValue := gjson.ParseBytes(targetBytes) + + result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode)) + if err != nil { + return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err) + } + + if condition.Invert { + result = !result + } + return result, nil +} + +func processNegativeIndex(jsonStr string, path string) string { + re := regexp.MustCompile(`\.(-\d+)`) + matches := re.FindAllStringSubmatch(path, -1) + + if len(matches) == 0 { + return path + } + + result := path + for _, match := range matches { + negIndex := match[1] + index, _ := strconv.Atoi(negIndex) + + arrayPath := strings.Split(path, negIndex)[0] + if strings.HasSuffix(arrayPath, ".") { + arrayPath = arrayPath[:len(arrayPath)-1] + } + + array := gjson.Get(jsonStr, arrayPath) + if array.IsArray() { + length := len(array.Array()) + actualIndex := length + index + if actualIndex >= 0 && actualIndex < length { + result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1) + } + } + } + + return result +} + +// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式 +func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) { + switch mode { + case "full": + return compareEqual(jsonValue, targetValue) + case "prefix": + return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil + case "suffix": + return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil + case "contains": + return strings.Contains(jsonValue.String(), targetValue.String()), nil + case "gt": + return compareNumeric(jsonValue, targetValue, "gt") + case "gte": + return compareNumeric(jsonValue, targetValue, "gte") + case "lt": + return compareNumeric(jsonValue, targetValue, "lt") + case "lte": + return compareNumeric(jsonValue, targetValue, "lte") + default: + return false, fmt.Errorf("unsupported comparison mode: %s", mode) + } +} + +func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) { + // 对布尔值特殊处理 + if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) && + (targetValue.Type == gjson.True || targetValue.Type == gjson.False) { + return jsonValue.Bool() == targetValue.Bool(), nil + } + + // 如果类型不同,报错 + if jsonValue.Type != targetValue.Type { + return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type) + } + + switch jsonValue.Type { + case gjson.True, gjson.False: + return jsonValue.Bool() == targetValue.Bool(), nil + case gjson.Number: + return jsonValue.Num == targetValue.Num, nil + case gjson.String: + return jsonValue.String() == targetValue.String(), nil + default: + return jsonValue.String() == targetValue.String(), nil + } +} + +func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) { + // 只有数字类型才支持数值比较 + if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number { + return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type) + } + + jsonNum := jsonValue.Num + targetNum := targetValue.Num + + switch operator { + case "gt": + return jsonNum > targetNum, nil + case "gte": + return jsonNum >= targetNum, nil + case "lt": + return jsonNum < targetNum, nil + case "lte": + return jsonNum <= targetNum, nil + default: + return false, fmt.Errorf("unsupported numeric operator: %s", operator) + } +} diff --git a/relay/common/override.go b/relay/common/override.go index 212cf7b47a..f527941b78 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -5,22 +5,17 @@ import ( "fmt" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "regexp" - "strconv" - "strings" ) -type ConditionOperation struct { - Path string `json:"path"` // JSON路径 - Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte - Value interface{} `json:"value"` // 匹配的值 - Invert bool `json:"invert"` // 反选功能,true表示取反结果 - PassMissingKey bool `json:"pass_missing_key"` // 未获取到json key时的行为 +const InternalPromptTokensKey = "NewAPIInternal" + +type NewAPIInternal struct { + PromptTokens int `json:"PromptTokens"` } type ParamOperation struct { Path string `json:"path"` - Mode string `json:"mode"` // delete, set, move, prepend, append + Mode string `json:"mode"` // delete, set, move, prepend, append, block, pass Value interface{} `json:"value"` KeepOrigin bool `json:"keep_origin"` From string `json:"from,omitempty"` @@ -29,20 +24,22 @@ type ParamOperation struct { Logic string `json:"logic,omitempty"` // AND, OR (默认OR) } -func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) { - if len(paramOverride) == 0 { - return jsonData, nil +func ApplyParamOverride(jsonData []byte, Relayinfo *RelayInfo) ([]byte, bool, error) { + if len(Relayinfo.ParamOverride) == 0 { + return jsonData, false, nil } // 尝试断言为操作格式 - if operations, ok := tryParseOperations(paramOverride); ok { + if operations, ok := tryParseOperations(Relayinfo.ParamOverride); ok { + info := &NewAPIInternal{PromptTokens: Relayinfo.PromptTokens} // 使用新方法 - result, err := applyOperations(string(jsonData), operations) - return []byte(result), err + result, isBlock, err := applyOperations(string(jsonData), operations, info) + return []byte(result), isBlock, err } // 直接使用旧方法 - return applyOperationsLegacy(jsonData, paramOverride) + result, err := applyOperationsLegacy(jsonData, Relayinfo.ParamOverride) + return result, false, err } func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) { @@ -122,167 +119,6 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, return nil, false } -func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) { - if len(conditions) == 0 { - return true, nil // 没有条件,直接通过 - } - results := make([]bool, len(conditions)) - for i, condition := range conditions { - result, err := checkSingleCondition(jsonStr, condition) - if err != nil { - return false, err - } - results[i] = result - } - - if strings.ToUpper(logic) == "AND" { - for _, result := range results { - if !result { - return false, nil - } - } - return true, nil - } else { - for _, result := range results { - if result { - return true, nil - } - } - return false, nil - } -} - -func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) { - // 处理负数索引 - path := processNegativeIndex(jsonStr, condition.Path) - value := gjson.Get(jsonStr, path) - if !value.Exists() { - if condition.PassMissingKey { - return true, nil - } - return false, nil - } - - // 利用gjson的类型解析 - targetBytes, err := json.Marshal(condition.Value) - if err != nil { - return false, fmt.Errorf("failed to marshal condition value: %v", err) - } - targetValue := gjson.ParseBytes(targetBytes) - - result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode)) - if err != nil { - return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err) - } - - if condition.Invert { - result = !result - } - return result, nil -} - -func processNegativeIndex(jsonStr string, path string) string { - re := regexp.MustCompile(`\.(-\d+)`) - matches := re.FindAllStringSubmatch(path, -1) - - if len(matches) == 0 { - return path - } - - result := path - for _, match := range matches { - negIndex := match[1] - index, _ := strconv.Atoi(negIndex) - - arrayPath := strings.Split(path, negIndex)[0] - if strings.HasSuffix(arrayPath, ".") { - arrayPath = arrayPath[:len(arrayPath)-1] - } - - array := gjson.Get(jsonStr, arrayPath) - if array.IsArray() { - length := len(array.Array()) - actualIndex := length + index - if actualIndex >= 0 && actualIndex < length { - result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1) - } - } - } - - return result -} - -// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式 -func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) { - switch mode { - case "full": - return compareEqual(jsonValue, targetValue) - case "prefix": - return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil - case "suffix": - return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil - case "contains": - return strings.Contains(jsonValue.String(), targetValue.String()), nil - case "gt": - return compareNumeric(jsonValue, targetValue, "gt") - case "gte": - return compareNumeric(jsonValue, targetValue, "gte") - case "lt": - return compareNumeric(jsonValue, targetValue, "lt") - case "lte": - return compareNumeric(jsonValue, targetValue, "lte") - default: - return false, fmt.Errorf("unsupported comparison mode: %s", mode) - } -} - -func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) { - // 对布尔值特殊处理 - if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) && - (targetValue.Type == gjson.True || targetValue.Type == gjson.False) { - return jsonValue.Bool() == targetValue.Bool(), nil - } - - // 如果类型不同,报错 - if jsonValue.Type != targetValue.Type { - return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type) - } - - switch jsonValue.Type { - case gjson.True, gjson.False: - return jsonValue.Bool() == targetValue.Bool(), nil - case gjson.Number: - return jsonValue.Num == targetValue.Num, nil - case gjson.String: - return jsonValue.String() == targetValue.String(), nil - default: - return jsonValue.String() == targetValue.String(), nil - } -} - -func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) { - // 只有数字类型才支持数值比较 - if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number { - return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type) - } - - jsonNum := jsonValue.Num - targetNum := targetValue.Num - - switch operator { - case "gt": - return jsonNum > targetNum, nil - case "gte": - return jsonNum >= targetNum, nil - case "lt": - return jsonNum < targetNum, nil - case "lte": - return jsonNum <= targetNum, nil - default: - return false, fmt.Errorf("unsupported numeric operator: %s", operator) - } -} - // applyOperationsLegacy 原参数覆盖方法 func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) { reqMap := make(map[string]interface{}) @@ -298,22 +134,35 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{} return json.Marshal(reqMap) } -func applyOperations(jsonStr string, operations []ParamOperation) (string, error) { - result := jsonStr +func applyOperations(jsonStr string, operations []ParamOperation, info *NewAPIInternal) (string, bool, error) { + jsonStrWithTokens, err := sjson.Set(jsonStr, InternalPromptTokensKey, info) + if err != nil { + return "", false, fmt.Errorf("failed to add %s: %v", InternalPromptTokensKey, err) + } + result := jsonStrWithTokens for _, op := range operations { // 检查条件是否满足 - ok, err := checkConditions(result, op.Conditions, op.Logic) + ok, err := CheckConditions(result, op.Conditions, op.Logic) if err != nil { - return "", err + return "", false, err } if !ok { continue // 条件不满足,跳过当前操作 } + // 处理block和pass操作 + if op.Mode == "block" { + blockMessage := "request blocked by param override conditions" + return result, true, fmt.Errorf(blockMessage) + } + if op.Mode == "pass" { + // 移除添加的内部字段 + result, _ = sjson.Delete(result, InternalPromptTokensKey) + return result, false, nil // 直接通过 + } // 处理路径中的负数索引 opPath := processNegativeIndex(result, op.Path) opFrom := processNegativeIndex(result, op.From) opTo := processNegativeIndex(result, op.To) - switch op.Mode { case "delete": result, err = sjson.Delete(result, opPath) @@ -329,13 +178,15 @@ func applyOperations(jsonStr string, operations []ParamOperation) (string, error case "append": result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false) default: - return "", fmt.Errorf("unknown operation: %s", op.Mode) + return "", false, fmt.Errorf("unknown operation: %s", op.Mode) } if err != nil { - return "", fmt.Errorf("operation %s failed: %v", op.Mode, err) + return "", false, fmt.Errorf("operation %s failed: %v", op.Mode, err) } } - return result, nil + // 移除添加的内部字段 + result, _ = sjson.Delete(result, InternalPromptTokensKey) + return result, false, nil } func moveValue(jsonStr, fromPath, toPath string) (string, error) { diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index a3ddf6d493..fcd988c7f0 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -143,8 +143,12 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + var isBlock bool + jsonData, isBlock, err = relaycommon.ApplyParamOverride(jsonData, info) if err != nil { + if isBlock { + return types.NewError(err, types.ErrorCodeContidionTriggerBlock, types.ErrOptionWithSkipRetry()) + } return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 1410da606d..9188e62ac7 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -155,8 +155,12 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + var isBlock bool + jsonData, isBlock, err = relaycommon.ApplyParamOverride(jsonData, info) if err != nil { + if isBlock { + return types.NewError(err, types.ErrorCodeContidionTriggerBlock, types.ErrOptionWithSkipRetry()) + } return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/image_handler.go b/relay/image_handler.go index 9c873d47f3..676efa9505 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -67,8 +67,12 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + var isBlock bool + jsonData, isBlock, err = relaycommon.ApplyParamOverride(jsonData, info) if err != nil { + if isBlock { + return types.NewError(err, types.ErrorCodeContidionTriggerBlock, types.ErrOptionWithSkipRetry()) + } return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 46d2e25f6f..f7a299d1d8 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -59,8 +59,12 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + var isBlock bool + jsonData, isBlock, err = relaycommon.ApplyParamOverride(jsonData, info) if err != nil { + if isBlock { + return types.NewError(err, types.ErrorCodeContidionTriggerBlock, types.ErrOptionWithSkipRetry()) + } return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 6958f96ef9..77b0dbad25 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -65,8 +65,12 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + var isBlock bool + jsonData, isBlock, err = relaycommon.ApplyParamOverride(jsonData, info) if err != nil { + if isBlock { + return types.NewError(err, types.ErrorCodeContidionTriggerBlock, types.ErrOptionWithSkipRetry()) + } return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/types/error.go b/types/error.go index 77a56dd256..4ac21b2475 100644 --- a/types/error.go +++ b/types/error.go @@ -50,6 +50,7 @@ const ( // channel error ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid" + ErrorCodeContidionTriggerBlock ErrorCode = "channel:condition_trigger_block" ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid" ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error" ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"