diff --git a/README.en.md b/README.en.md index 69fd32f8ba..2054cc9ba2 100644 --- a/README.en.md +++ b/README.en.md @@ -82,7 +82,7 @@ New API offers a wide range of features, please refer to [Features Introduction] 7. ⚖️ Support for weighted random channel selection 8. 📈 Data dashboard (console) 9. 🔒 Token grouping and model restrictions -10. 🤖 Support for more authorization login methods (LinuxDO, Telegram, OIDC) +10. 🤖 Support for more authorization login methods (LinuxDO, Telegram, OIDC, Discord) 11. 🔄 Support for Rerank models (Cohere and Jina), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank) 12. ⚡ Support for OpenAI Realtime API (including Azure channels), [API Documentation](https://docs.newapi.pro/api/openai-realtime) 13. ⚡ Support for Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat) diff --git a/README.md b/README.md index 45b048340e..eddcfbe94f 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do 7. ⚖️ 支持渠道加权随机 8. 📈 数据看板(控制台) 9. 🔒 令牌分组、模型限制 -10. 🤖 支持更多授权登陆方式(LinuxDO,Telegram、OIDC) +10. 🤖 支持更多授权登陆方式(LinuxDO、Telegram、OIDC、Discord) 11. 🔄 支持Rerank模型(Cohere和Jina),[接口文档](https://docs.newapi.pro/api/jinaai-rerank) 12. ⚡ 支持OpenAI Realtime API(包括Azure渠道),[接口文档](https://docs.newapi.pro/api/openai-realtime) 13. ⚡ 支持Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat) diff --git a/common/constants.go b/common/constants.go index 3052241150..cf59ed935d 100644 --- a/common/constants.go +++ b/common/constants.go @@ -43,6 +43,7 @@ var PasswordLoginEnabled = true var PasswordRegisterEnabled = true var EmailVerificationEnabled = false var GitHubOAuthEnabled = false +var DiscordOAuthEnabled = false var LinuxDOOAuthEnabled = false var WeChatAuthEnabled = false var TelegramOAuthEnabled = false @@ -81,6 +82,8 @@ var SMTPToken = "" var GitHubClientId = "" var GitHubClientSecret = "" +var DiscordClientId = "" +var DiscordClientSecret = "" var LinuxDOClientId = "" var LinuxDOClientSecret = "" diff --git a/controller/discord.go b/controller/discord.go new file mode 100644 index 0000000000..b786932338 --- /dev/null +++ b/controller/discord.go @@ -0,0 +1,224 @@ +package controller + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "one-api/common" + "one-api/model" + "one-api/setting" + "strconv" + "strings" + "time" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +type DiscordResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type DiscordUser struct { + UID string `json:"id"` + ID string `json:"username"` + Name string `json:"global_name"` +} + +func getDiscordUserInfoByCode(code string) (*DiscordUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + + values := url.Values{} + values.Set("client_id", common.DiscordClientId) + values.Set("client_secret", common.DiscordClientSecret) + values.Set("code", code) + values.Set("grant_type", "authorization_code") + values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", setting.ServerAddress)) + formData := values.Encode() + req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + common.SysLog(err.Error()) + return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!") + } + defer res.Body.Close() + var discordResponse DiscordResponse + err = json.NewDecoder(res.Body).Decode(&discordResponse) + if err != nil { + return nil, err + } + + if discordResponse.AccessToken == "" { + common.SysError("Discord 获取 Token 失败,请检查设置!") + return nil, errors.New("Discord 获取 Token 失败,请检查设置!") + } + + req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken) + res2, err := client.Do(req) + if err != nil { + common.SysLog(err.Error()) + return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!") + } + defer res2.Body.Close() + if res2.StatusCode != http.StatusOK { + common.SysError("Discord 获取用户信息失败!请检查设置!") + return nil, errors.New("Discord 获取用户信息失败!请检查设置!") + } + + var discordUser DiscordUser + err = json.NewDecoder(res2.Body).Decode(&discordUser) + if err != nil { + return nil, err + } + if discordUser.UID == "" || discordUser.ID == "" { + common.SysError("Discord 获取用户信息为空!请检查设置!") + return nil, errors.New("Discord 获取用户信息为空!请检查设置!") + } + return &discordUser, nil +} + +func DiscordOAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + DiscordBind(c) + return + } + if !common.DiscordOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 discord 登录以及注册", + }) + return + } + code := c.Query("code") + discordUser, err := getDiscordUserInfoByCode(code) + if err != nil { + common.ApiError(c, err) + return + } + user := model.User{ + DiscordId: discordUser.UID, + } + if model.IsDiscordIdAlreadyTaken(user.DiscordId) { + err := user.FillUserByDiscordId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if common.RegisterEnabled { + if discordUser.ID != "" { + user.Username = discordUser.ID + } else { + user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1) + } + if discordUser.Name != "" { + user.DisplayName = discordUser.Name + } else { + user.DisplayName = "Discord User" + } + err := user.Insert(0) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != common.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + setupLogin(&user, c) +} + +func DiscordBind(c *gin.Context) { + if !common.DiscordOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 Discord 登录以及注册", + }) + return + } + code := c.Query("code") + discordUser, err := getDiscordUserInfoByCode(code) + if err != nil { + common.ApiError(c, err) + return + } + user := model.User{ + DiscordId: discordUser.UID, + } + if model.IsDiscordIdAlreadyTaken(user.DiscordId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 Discord 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + common.ApiError(c, err) + return + } + user.DiscordId = discordUser.UID + err = user.Update(false) + if err != nil { + common.ApiError(c, err) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/misc.go b/controller/misc.go index a3ed9be9ac..4274a2b9f4 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -46,6 +46,8 @@ func GetStatus(c *gin.Context) { "email_verification": common.EmailVerificationEnabled, "github_oauth": common.GitHubOAuthEnabled, "github_client_id": common.GitHubClientId, + "discord_oauth": common.DiscordOAuthEnabled, + "discord_client_id": common.DiscordClientId, "linuxdo_oauth": common.LinuxDOOAuthEnabled, "linuxdo_client_id": common.LinuxDOClientId, "telegram_oauth": common.TelegramOAuthEnabled, diff --git a/controller/option.go b/controller/option.go index decdb0d405..f2f52a9379 100644 --- a/controller/option.go +++ b/controller/option.go @@ -54,6 +54,14 @@ func UpdateOption(c *gin.Context) { }) return } + case "DiscordOAuthEnabled": + if option.Value == "true" && common.DiscordClientId == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用 Discord OAuth,请先填入 Discord Client Id 以及 Discord Client Secret!", + }) + return + } case "oidc.enabled": if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" { c.JSON(http.StatusOK, gin.H{ diff --git a/docs/api/web_api.md b/docs/api/web_api.md index e64fd3594f..aa88a606c6 100644 --- a/docs/api/web_api.md +++ b/docs/api/web_api.md @@ -42,6 +42,7 @@ | 方法 | 路径 | 鉴权 | 说明 | |------|------|------|------| | GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 | +| GET | /api/oauth/discord | 公开 | Discord 通用 OAuth 跳转 | | GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 | | GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 | | GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 | diff --git a/i18n/zh-cn.json b/i18n/zh-cn.json index 7b57b51ac9..6d1be799a9 100644 --- a/i18n/zh-cn.json +++ b/i18n/zh-cn.json @@ -3,6 +3,7 @@ "登 录": "登 录", "使用 微信 继续": "使用 微信 继续", "使用 GitHub 继续": "使用 GitHub 继续", + "使用 Discord 继续": "使用 Discord 继续", "使用 LinuxDO 继续": "使用 LinuxDO 继续", "使用 邮箱或用户名 登录": "使用 邮箱或用户名 登录", "没有账户?": "没有账户?", @@ -1032,6 +1033,7 @@ "添加额度": "添加额度", "第三方账户绑定状态(只读)": "第三方账户绑定状态(只读)", "已绑定的 GitHub 账户": "已绑定的 GitHub 账户", + "已绑定的 Discord 账户": "已绑定的 Discord 账户", "已绑定的 OIDC 账户": "已绑定的 OIDC 账户", "已绑定的微信账户": "已绑定的微信账户", "已绑定的邮箱账户": "已绑定的邮箱账户", diff --git a/model/option.go b/model/option.go index 05b99b41af..fe8d603e08 100644 --- a/model/option.go +++ b/model/option.go @@ -36,6 +36,7 @@ func InitOptionMap() { common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) + common.OptionMap["DiscordOAuthEnabled"] = strconv.FormatBool(common.DiscordOAuthEnabled) common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled) common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) @@ -88,6 +89,8 @@ func InitOptionMap() { common.OptionMap["PayMethods"] = setting.PayMethods2JsonString() common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientSecret"] = "" + common.OptionMap["DiscordClientId"] = "" + common.OptionMap["DiscordClientSecret"] = "" common.OptionMap["TelegramBotToken"] = "" common.OptionMap["TelegramBotName"] = "" common.OptionMap["WeChatServerAddress"] = "" @@ -214,6 +217,8 @@ func updateOptionMap(key string, value string) (err error) { common.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": common.GitHubOAuthEnabled = boolValue + case "DiscordOAuthEnabled": + common.DiscordOAuthEnabled = boolValue case "LinuxDOOAuthEnabled": common.LinuxDOOAuthEnabled = boolValue case "WeChatAuthEnabled": @@ -332,6 +337,10 @@ func updateOptionMap(key string, value string) (err error) { common.GitHubClientId = value case "GitHubClientSecret": common.GitHubClientSecret = value + case "DiscordClientId": + common.DiscordClientId = value + case "DiscordClientSecret": + common.DiscordClientSecret = value case "LinuxDOClientId": common.LinuxDOClientId = value case "LinuxDOClientSecret": diff --git a/model/user.go b/model/user.go index 6021f495c0..15f38a8104 100644 --- a/model/user.go +++ b/model/user.go @@ -25,6 +25,7 @@ type User struct { Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` + DiscordId string `json:"discord_id" gorm:"column:discord_id;index"` OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"` @@ -451,6 +452,14 @@ func (user *User) FillUserByGitHubId() error { return nil } +func (user *User) FillUserByDiscordId() error { + if user.DiscordId == "" { + return errors.New("discord id 为空!") + } + DB.Where(User{DiscordId: user.DiscordId}).First(user) + return nil +} + func (user *User) FillUserByOidcId() error { if user.OidcId == "" { return errors.New("oidc id 为空!") @@ -490,6 +499,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsDiscordIdAlreadyTaken(discordId string) bool { + return DB.Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1 +} + func IsOidcIdAlreadyTaken(oidcId string) bool { return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 } diff --git a/router/api-router.go b/router/api-router.go index bc49803a2e..0cc35c0238 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -28,6 +28,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) + apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth) apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth) apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) diff --git a/web/src/App.js b/web/src/App.js index 2d715767d1..b2f3ef25bf 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -169,6 +169,14 @@ function App() { } /> + } key={location.pathname}> + + + } + /> { let navigate = useNavigate(); @@ -52,6 +54,7 @@ const LoginForm = () => { const [showEmailLogin, setShowEmailLogin] = useState(false); const [wechatLoading, setWechatLoading] = useState(false); const [githubLoading, setGithubLoading] = useState(false); + const [discordLoading, setDiscordLoading] = useState(false); const [oidcLoading, setOidcLoading] = useState(false); const [linuxdoLoading, setLinuxdoLoading] = useState(false); const [emailLoginLoading, setEmailLoginLoading] = useState(false); @@ -215,6 +218,17 @@ const LoginForm = () => { } }; + // 包装的GitHub登录点击处理 + const handleDiscordClick = () => { + setDiscordLoading(true); + try { + onDiscordOAuthClicked(status.discord_client_id); + } finally { + // 由于重定向,这里不会执行到,但为了完整性添加 + setTimeout(() => setDiscordLoading(false), 3000); + } + }; + // 包装的OIDC登录点击处理 const handleOIDCClick = () => { setOidcLoading(true); @@ -304,6 +318,20 @@ const LoginForm = () => { )} + {status.discord_oauth && ( + + )} + {status.oidc_enabled && ( )} + {status.discord_oauth && ( + + )} + {status.oidc_enabled && ( + + + {/* OIDC绑定 */} { GitHubOAuthEnabled: '', GitHubClientId: '', GitHubClientSecret: '', + DiscordOAuthEnabled: '', + DiscordClientId: '', + DiscordClientSecret: '', 'oidc.enabled': '', 'oidc.client_id': '', 'oidc.client_secret': '', @@ -105,6 +108,7 @@ const SystemSetting = () => { case 'EmailAliasRestrictionEnabled': case 'SMTPSSLEnabled': case 'LinuxDOOAuthEnabled': + case 'DiscordOAuthEnabled': case 'oidc.enabled': case 'WorkerAllowHttpImageRequestEnabled': item.value = toBoolean(item.value); @@ -334,6 +338,28 @@ const SystemSetting = () => { } }; + + const submitDiscordOAuth = async () => { + const options = []; + + if (originInputs['DiscordClientId'] !== inputs.DiscordClientId) { + options.push({ key: 'DiscordClientId', value: inputs.DiscordClientId }); + } + if ( + originInputs['DiscordClientSecret'] !== inputs.DiscordClientSecret && + inputs.DiscordClientSecret !== '' + ) { + options.push({ + key: 'DiscordClientSecret', + value: inputs.DiscordClientSecret, + }); + } + + if (options.length > 0) { + await updateOptions(options); + } + }; + const submitOIDCSettings = async () => { if (inputs['oidc.well_known'] && inputs['oidc.well_known'] !== '') { if ( @@ -616,6 +642,15 @@ const SystemSetting = () => { > {t('允许通过 GitHub 账户登录 & 注册')} + + handleCheckboxChange('DiscordOAuthEnabled', e) + } + > + {t('允许通过 Discord 进行登录 & 注册')} + { + + + {t('用以支持通过 Discord 进行登录注册')} + + + + + + + + + + + + diff --git a/web/src/helpers/api.js b/web/src/helpers/api.js index cad1dd134b..9cc061fd03 100644 --- a/web/src/helpers/api.js +++ b/web/src/helpers/api.js @@ -193,6 +193,17 @@ export async function getOAuthState() { } } +export async function onDiscordOAuthClicked(client_id) { + const state = await getOAuthState(); + if (!state) return; + const redirect_uri = `${window.location.origin}/oauth/discord`; + const response_type = 'code'; + const scope = 'identify+openid'; + window.open( + `https://discord.com/oauth2/authorize?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`, + ); +} + export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) { const state = await getOAuthState(); if (!state) return; diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 1ff11e1f67..fdb4823448 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -8,6 +8,7 @@ "注 册": "Sign Up", "使用 邮箱或用户名 登录": "Sign in with Email or Username", "使用 GitHub 继续": "Continue with GitHub", + "使用 Discord 继续": "Continue with Discord", "使用 OIDC 继续": "Continue with OIDC", "使用 微信 继续": "Continue with WeChat", "使用 LinuxDO 继续": "Continue with LinuxDO", @@ -1498,6 +1499,7 @@ "用户分组和额度管理": "User Group and Quota Management", "绑定信息": "Binding Information", "第三方账户绑定状态(只读)": "Third-party account binding status (read-only)", + "已绑定的 Discord 账户": "Bound Discord accounts", "已绑定的 OIDC 账户": "Bound OIDC accounts", "使用兑换码充值余额": "Recharge balance with redemption code", "支持多种支付方式": "Support multiple payment methods", diff --git a/web/src/pages/User/EditUser.js b/web/src/pages/User/EditUser.js index bfccf37b8e..da42d8fbf6 100644 --- a/web/src/pages/User/EditUser.js +++ b/web/src/pages/User/EditUser.js @@ -52,6 +52,7 @@ const EditUser = (props) => { display_name: '', password: '', github_id: '', + discord_id: '', oidc_id: '', wechat_id: '', telegram_id: '', @@ -288,7 +289,7 @@ const EditUser = (props) => { - {['github_id', 'oidc_id', 'wechat_id', 'email', 'telegram_id'].map((field) => ( + {['github_id', 'discord_id', 'oidc_id', 'wechat_id', 'email', 'telegram_id'].map((field) => (