This commit is contained in:
parent
acadbf6ae6
commit
c5e02a5fec
|
|
@ -33,6 +33,9 @@ import org.springframework.web.cors.CorsConfiguration;
|
||||||
import org.springframework.web.cors.CorsConfigurationSource;
|
import org.springframework.web.cors.CorsConfigurationSource;
|
||||||
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
|
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
|
||||||
|
|
||||||
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
|
import jakarta.servlet.http.HttpServletResponse;
|
||||||
|
import java.io.IOException;
|
||||||
import java.security.KeyPair;
|
import java.security.KeyPair;
|
||||||
import java.security.KeyPairGenerator;
|
import java.security.KeyPairGenerator;
|
||||||
import java.security.NoSuchAlgorithmException;
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
|
@ -46,7 +49,25 @@ import java.util.UUID;
|
||||||
@Configuration
|
@Configuration
|
||||||
public class SecurityConfig {
|
public class SecurityConfig {
|
||||||
|
|
||||||
|
// 自定义认证入口点,保留client_id参数
|
||||||
|
public static class CustomLoginUrlAuthenticationEntryPoint extends LoginUrlAuthenticationEntryPoint {
|
||||||
|
public CustomLoginUrlAuthenticationEntryPoint(String loginFormUrl) {
|
||||||
|
super(loginFormUrl);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response, org.springframework.security.core.AuthenticationException exception) {
|
||||||
|
String loginUrl = super.determineUrlToUseForThisRequest(request, response, exception);
|
||||||
|
|
||||||
|
// 获取client_id参数
|
||||||
|
String clientId = request.getParameter("client_id");
|
||||||
|
if (clientId != null && !clientId.isEmpty()) {
|
||||||
|
loginUrl += "?client_id=" + clientId;
|
||||||
|
}
|
||||||
|
|
||||||
|
return loginUrl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 配置授权服务器安全过滤器链
|
// 配置授权服务器安全过滤器链
|
||||||
@Bean
|
@Bean
|
||||||
|
|
@ -58,7 +79,7 @@ public class SecurityConfig {
|
||||||
.oidc(Customizer.withDefaults()); // 启用 OpenID Connect
|
.oidc(Customizer.withDefaults()); // 启用 OpenID Connect
|
||||||
|
|
||||||
http.exceptionHandling(exceptions ->
|
http.exceptionHandling(exceptions ->
|
||||||
exceptions.authenticationEntryPoint(new LoginUrlAuthenticationEntryPoint("/login"))
|
exceptions.authenticationEntryPoint(new CustomLoginUrlAuthenticationEntryPoint("/login"))
|
||||||
);
|
);
|
||||||
|
|
||||||
return http.build();
|
return http.build();
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package com.tuoheng.oauth.oidc.controller;
|
||||||
|
|
||||||
import org.springframework.security.web.csrf.CsrfToken;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
import org.springframework.web.bind.annotation.GetMapping;
|
import org.springframework.web.bind.annotation.GetMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RequestParam;
|
||||||
import org.springframework.web.bind.annotation.ResponseBody;
|
import org.springframework.web.bind.annotation.ResponseBody;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
|
||||||
|
|
@ -15,7 +16,7 @@ public class LoginController {
|
||||||
|
|
||||||
@GetMapping("/login")
|
@GetMapping("/login")
|
||||||
@ResponseBody
|
@ResponseBody
|
||||||
public String login(HttpServletRequest request) throws IOException {
|
public String login(HttpServletRequest request, @RequestParam(value = "client_id", required = false) String clientId) throws IOException {
|
||||||
// 读取静态HTML文件
|
// 读取静态HTML文件
|
||||||
String htmlContent = new String(Files.readAllBytes(Paths.get("src/main/resources/static/login.html")));
|
String htmlContent = new String(Files.readAllBytes(Paths.get("src/main/resources/static/login.html")));
|
||||||
|
|
||||||
|
|
@ -27,6 +28,17 @@ public class LoginController {
|
||||||
"id=\"csrf-parameter\" name=\"" + csrfToken.getParameterName() + "\" value=\"" + csrfToken.getToken() + "\"");
|
"id=\"csrf-parameter\" name=\"" + csrfToken.getParameterName() + "\" value=\"" + csrfToken.getToken() + "\"");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果有client_id参数,在页面中显示
|
||||||
|
if (clientId != null && !clientId.isEmpty()) {
|
||||||
|
// 在页面标题中显示client_id
|
||||||
|
htmlContent = htmlContent.replace("<h1>OIDC 登录</h1>",
|
||||||
|
"<h1>OIDC 登录 - " + clientId + "</h1>");
|
||||||
|
|
||||||
|
// 在页面中添加client_id信息
|
||||||
|
htmlContent = htmlContent.replace("<p>请输入您的凭据以继续</p>",
|
||||||
|
"<p>请输入您的凭据以继续</p><p style=\"color: #666; font-size: 12px; margin-top: 5px;\">客户端: " + clientId + "</p>");
|
||||||
|
}
|
||||||
|
|
||||||
return htmlContent;
|
return htmlContent;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Loading…
Reference in New Issue