Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.babzip.backend.global.jwt.JwtAuthenticationFilter;
import com.babzip.backend.global.jwt.TokenProvider;
import com.babzip.backend.global.oauth.handler.OAuth2AuthenticationSuccessHandler;
import com.babzip.backend.global.oauth.resolver.CustomAuthorizationRequestResolver;
import com.babzip.backend.global.oauth.service.CustomOAuth2UserService;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Bean;
Expand All @@ -17,6 +18,7 @@
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;

Expand All @@ -29,6 +31,7 @@ public class SecurityConfig {
private final CustomOAuth2UserService customOAuth2UserService;
private final OAuth2AuthenticationSuccessHandler oAuth2AuthenticationSuccessHandler;
private final AuthenticationManager authenticationManager;
private final CustomAuthorizationRequestResolver customAuthorizationRequestResolver;

@Bean
public SecurityFilterChain filterChainPermitAll(HttpSecurity http) throws Exception {
Expand All @@ -45,11 +48,15 @@ public HttpSecurity defaultSecurity(HttpSecurity http) throws Exception {
.httpBasic(AbstractHttpConfigurer::disable)
.formLogin(AbstractHttpConfigurer::disable)
.cors(cors -> cors.configurationSource(CorsConfig.corsConfigurationSource()))
.oauth2Login(oauth2 ->
oauth2.userInfoEndpoint(c -> c.userService(customOAuth2UserService))
.successHandler(oAuth2AuthenticationSuccessHandler))
.addFilterAfter(new JwtAuthenticationFilter(authenticationManager), UsernamePasswordAuthenticationFilter.class)
;
.oauth2Login(oauth2 -> oauth2
.authorizationEndpoint(auth ->
auth.authorizationRequestResolver(customAuthorizationRequestResolver))
.userInfoEndpoint(ui ->
ui.userService(customOAuth2UserService))
.successHandler(oAuth2AuthenticationSuccessHandler)
)
.addFilterAfter(new JwtAuthenticationFilter(authenticationManager),
UsernamePasswordAuthenticationFilter.class);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
import com.babzip.backend.global.jwt.JwtHandler;
import com.babzip.backend.global.jwt.JwtUserClaim;
import com.babzip.backend.global.oauth.service.OAuth2UserPrincipal;
import com.babzip.backend.global.oauth.util.RedirectUrlValidator;
import com.babzip.backend.global.oauth.util.StateUtil;
import com.babzip.backend.token.entity.Token;
import com.babzip.backend.user.domain.UserRole;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
Expand All @@ -16,6 +20,7 @@
import org.springframework.web.util.UriComponentsBuilder;

import java.io.IOException;
import java.util.Base64;

@Component
@RequiredArgsConstructor
Expand All @@ -28,20 +33,22 @@ public class OAuth2AuthenticationSuccessHandler implements AuthenticationSuccess
public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication authentication) throws IOException, ServletException {


String encodedState = request.getParameter("state");
String redirectUri = StateUtil.decode(encodedState);

// 2) 화이트리스트 검증
RedirectUrlValidator.validate(redirectUri);

OAuth2UserPrincipal principal = (OAuth2UserPrincipal) authentication.getPrincipal();
Long userId = principal.getUser().getId();
UserRole role = principal.getUser().getRole();

JwtUserClaim jwtUserClaim = new JwtUserClaim(userId,role);
Token token = jwtHandler.createTokens(jwtUserClaim);

String targetUrl = "http://localhost:5173/auth/success";




// 토큰 붙여서 리다이렉트
String redirectUrl = UriComponentsBuilder.fromUriString(targetUrl)
String redirectUrl = UriComponentsBuilder.fromUriString(redirectUri)
.queryParam("accessToken", token.getAccessToken())
.queryParam("refreshToken", token.getRefreshToken())
.build().toUriString();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.babzip.backend.global.oauth.resolver;

import com.babzip.backend.global.oauth.util.RedirectUrlValidator;
import com.babzip.backend.global.oauth.util.StateUtil;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.stereotype.Component;

@Configuration
public class CustomAuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver {
private final DefaultOAuth2AuthorizationRequestResolver delegate;

public CustomAuthorizationRequestResolver(ClientRegistrationRepository repo) {
this.delegate = new DefaultOAuth2AuthorizationRequestResolver(repo, "/oauth2/authorization");
}

@Override
public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
OAuth2AuthorizationRequest original = delegate.resolve(request);
return customizeState(request, original);
}

@Override
public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String clientRegistrationId) {
OAuth2AuthorizationRequest original = delegate.resolve(request, clientRegistrationId);
return customizeState(request, original);
}

private OAuth2AuthorizationRequest customizeState(HttpServletRequest request,
OAuth2AuthorizationRequest original) {
if (original == null) return null;

// 프론트에서 ?redirect_uri=... 로 넘긴 값
String rawRedirect = request.getParameter("redirect_uri");
if (rawRedirect == null || rawRedirect.isBlank()) {
return original; // redirect_uri 없이도 로그인 가능하도록
}

// 화이트리스트 검증
RedirectUrlValidator.validate(rawRedirect);

// 인코딩
String encodedState = StateUtil.encode(rawRedirect);

return OAuth2AuthorizationRequest.from(original)
.state(encodedState)
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.babzip.backend.global.oauth.util;

import java.util.Set;

public class RedirectUrlValidator {

private static final Set<String> ALLOWED_PREFIXES = Set.of(
"https://your-app.netlify.app",
"http://localhost:5173"
);

public static void validate(String uri) {
boolean allowed = ALLOWED_PREFIXES.stream().anyMatch(uri::startsWith);
if (!allowed) {
throw new IllegalArgumentException("허용되지 않은 redirectUri: " + uri);
}
}
}
42 changes: 42 additions & 0 deletions src/main/java/com/babzip/backend/global/oauth/util/StateUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.babzip.backend.global.oauth.util;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;

public class StateUtil {

private static final ObjectMapper MAPPER = new ObjectMapper();

private StateUtil() {}

/** redirectUri → JSON → Base64URL(without padding) */
public static String encode(String redirectUri) {
try {
String json = MAPPER.writeValueAsString(Map.of("redirectUri", redirectUri));
return Base64.getUrlEncoder()
.withoutPadding()
.encodeToString(json.getBytes(StandardCharsets.UTF_8));
} catch (Exception e) {
throw new IllegalStateException("state encoding 실패", e);
}
}

/** Base64URL → JSON → redirectUri */
public static String decode(String encodedState) {
try {
byte[] bytes = Base64.getUrlDecoder().decode(encodedState);
JsonNode node = MAPPER.readTree(bytes);

if (node.hasNonNull("redirectUri")) {
return node.get("redirectUri").asText();
}
throw new IllegalArgumentException("redirectUri 누락");
} catch (Exception e) {
throw new IllegalArgumentException("state 디코딩 실패", e);
}
}
}