diff --git a/oidc/src/main/java/com/tuoheng/oauth/oidc/config/SecurityConfig.java b/oidc/src/main/java/com/tuoheng/oauth/oidc/config/SecurityConfig.java index 053b662..d88a450 100644 --- a/oidc/src/main/java/com/tuoheng/oauth/oidc/config/SecurityConfig.java +++ b/oidc/src/main/java/com/tuoheng/oauth/oidc/config/SecurityConfig.java @@ -33,6 +33,9 @@ import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfigurationSource; import org.springframework.web.cors.UrlBasedCorsConfigurationSource; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.NoSuchAlgorithmException; @@ -46,7 +49,25 @@ import java.util.UUID; @Configuration public class SecurityConfig { + // 自定义认证入口点,保留client_id参数 + public static class CustomLoginUrlAuthenticationEntryPoint extends LoginUrlAuthenticationEntryPoint { + public CustomLoginUrlAuthenticationEntryPoint(String loginFormUrl) { + super(loginFormUrl); + } + @Override + protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response, org.springframework.security.core.AuthenticationException exception) { + String loginUrl = super.determineUrlToUseForThisRequest(request, response, exception); + + // 获取client_id参数 + String clientId = request.getParameter("client_id"); + if (clientId != null && !clientId.isEmpty()) { + loginUrl += "?client_id=" + clientId; + } + + return loginUrl; + } + } // 配置授权服务器安全过滤器链 @Bean @@ -58,7 +79,7 @@ public class SecurityConfig { .oidc(Customizer.withDefaults()); // 启用 OpenID Connect http.exceptionHandling(exceptions -> - exceptions.authenticationEntryPoint(new LoginUrlAuthenticationEntryPoint("/login")) + exceptions.authenticationEntryPoint(new CustomLoginUrlAuthenticationEntryPoint("/login")) ); return http.build(); diff --git a/oidc/src/main/java/com/tuoheng/oauth/oidc/controller/LoginController.java b/oidc/src/main/java/com/tuoheng/oauth/oidc/controller/LoginController.java index 0f92908..96d23f2 100644 --- a/oidc/src/main/java/com/tuoheng/oauth/oidc/controller/LoginController.java +++ b/oidc/src/main/java/com/tuoheng/oauth/oidc/controller/LoginController.java @@ -2,6 +2,7 @@ package com.tuoheng.oauth.oidc.controller; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.bind.annotation.RestController; @@ -15,7 +16,7 @@ public class LoginController { @GetMapping("/login") @ResponseBody - public String login(HttpServletRequest request) throws IOException { + public String login(HttpServletRequest request, @RequestParam(value = "client_id", required = false) String clientId) throws IOException { // 读取静态HTML文件 String htmlContent = new String(Files.readAllBytes(Paths.get("src/main/resources/static/login.html"))); @@ -27,6 +28,17 @@ public class LoginController { "id=\"csrf-parameter\" name=\"" + csrfToken.getParameterName() + "\" value=\"" + csrfToken.getToken() + "\""); } + // 如果有client_id参数,在页面中显示 + if (clientId != null && !clientId.isEmpty()) { + // 在页面标题中显示client_id + htmlContent = htmlContent.replace("

OIDC 登录

", + "

OIDC 登录 - " + clientId + "

"); + + // 在页面中添加client_id信息 + htmlContent = htmlContent.replace("

请输入您的凭据以继续

", + "

请输入您的凭据以继续

客户端: " + clientId + "

"); + } + return htmlContent; } } \ No newline at end of file