diff --git a/.env.example b/.env.example index 6c2d0d2..f392ed0 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,5 @@ POSTGRES_DB=p2p_shopping POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres +MONGO_DB=p2p_shopping_mongo JWT_SECRET=your-secret-key-here-at-least-32-characters-long \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f02281a..e816115 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -13,11 +13,11 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 # Shallow clones should be disabled for a better relevancy of analysis - - name: Set up JDK 17 + - name: Set up JDK 21 uses: actions/setup-java@v4 with: - java-version: 17 - distribution: 'zulu' # Alternative distribution options are available + java-version: 21 + distribution: 'temurin' - name: Cache SonarQube packages uses: actions/cache@v4 with: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..9634466 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,82 @@ +name: Tests +on: + push: + branches: + - main + - develop + pull_request: + types: [opened, synchronize, reopened] + +jobs: + test: + name: Run Tests + runs-on: ubuntu-latest + + services: + postgres: + image: postgis/postgis:16-3.4 + env: + POSTGRES_DB: p2p_shopping_test + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U postgres -d p2p_shopping_test" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + mongodb: + image: mongo:7 + env: + MONGO_INITDB_DATABASE: p2p_shopping_test + ports: + - 27017:27017 + options: >- + --health-cmd "mongosh --eval 'db.adminCommand(\"ping\")'" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - uses: actions/checkout@v4 + + - name: Set up JDK 21 + uses: actions/setup-java@v4 + with: + java-version: 21 + distribution: 'temurin' + + - name: Cache Gradle packages + uses: actions/cache@v4 + with: + path: | + ~/.gradle/caches + ~/.gradle/wrapper + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*', '**/gradle-wrapper.properties') }} + restore-keys: ${{ runner.os }}-gradle- + + - name: Grant execute permission for gradlew + run: chmod +x gradlew + + - name: Run tests + env: + SPRING_DATA_MONGODB_URI: mongodb://localhost:27017/p2p_shopping_test + run: ./gradlew test --info + + - name: Upload test reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-reports + path: build/reports/tests/ + retention-days: 7 + + - name: Upload coverage report + if: always() + uses: actions/upload-artifact@v4 + with: + name: coverage-reports + path: build/reports/jacoco/ + retention-days: 7 diff --git a/build.gradle.kts b/build.gradle.kts index d2ba82d..1f0aa81 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -47,8 +47,8 @@ dependencies { testImplementation("org.springframework.boot:spring-boot-starter-test") testImplementation("org.junit.platform:junit-platform-suite-api") - testImplementation("org.testcontainers:testcontainers:1.19.0") - testImplementation("org.testcontainers:postgresql:1.19.0") + testImplementation("org.testcontainers:testcontainers") + testImplementation("org.testcontainers:postgresql:1.21.4") testRuntimeOnly("org.junit.platform:junit-platform-suite-engine") testRuntimeOnly("org.junit.platform:junit-platform-launcher") } diff --git a/docker-compose.yml b/docker-compose.yml index a42ee5b..6a2f68f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,5 +18,22 @@ services: timeout: 5s retries: 5 + mongodb: + image: mongo:7 + container_name: p2p_shopping_mongodb + restart: unless-stopped + environment: + MONGO_INITDB_DATABASE: ${MONGO_DB:-p2p_shopping_mongo} + ports: + - "27017:27017" + volumes: + - mongodata:/data/db + healthcheck: + test: ["CMD", "mongosh", "--eval", "db.adminCommand('ping')"] + interval: 10s + timeout: 5s + retries: 5 + volumes: - pgdata: \ No newline at end of file + pgdata: + mongodata: \ No newline at end of file diff --git a/src/main/java/com/p2ps/auth/security/JwtAuthFilter.java b/src/main/java/com/p2ps/auth/security/JwtAuthFilter.java index f3ba0de..73c7727 100644 --- a/src/main/java/com/p2ps/auth/security/JwtAuthFilter.java +++ b/src/main/java/com/p2ps/auth/security/JwtAuthFilter.java @@ -22,41 +22,55 @@ public JwtAuthFilter(JwtUtil jwtUtil) { this.jwtUtil = jwtUtil; } - @Override - public void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { + public String extractBearerToken(String authorizationHeader) { + if (authorizationHeader == null || authorizationHeader.isBlank()) { + return null; + } + + String token = authorizationHeader.trim(); - String token = null; - String userEmail = null; + if (token.regionMatches(true, 0, "Bearer ", 0, 7)) { + return token.substring(7).trim(); + } + + return token; + } - // Extract token from Authorization header - String authorizationHeader = request.getHeader("Authorization"); - if (authorizationHeader != null && authorizationHeader.startsWith("Bearer ")) { - token = authorizationHeader.substring(7); + public UsernamePasswordAuthenticationToken authenticateToken(String token) { + if (token == null || token.isBlank()) { + return null; } try { - if (token != null) { - userEmail = jwtUtil.extractEmail(token); + String userEmail = jwtUtil.extractEmail(token); + + if (userEmail != null && !jwtUtil.isTokenExpired(token)) { + return new UsernamePasswordAuthenticationToken(userEmail, null, new ArrayList<>()); } + } catch (Exception _) { + return null; + } - // Authenticate if user is not already in the SecurityContext - if (userEmail != null && SecurityContextHolder.getContext().getAuthentication() == null && !jwtUtil.isTokenExpired(token)){ - UsernamePasswordAuthenticationToken authToken = new UsernamePasswordAuthenticationToken( - userEmail, null, new ArrayList<>() - ); - authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request)); + return null; + } - // Set user as authenticated - SecurityContextHolder.getContext().setAuthentication(authToken); - } + @Override + public void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + String token = extractBearerToken(request.getHeader("Authorization")); + + try { + UsernamePasswordAuthenticationToken authToken = authenticateToken(token); + + if (authToken != null && SecurityContextHolder.getContext().getAuthentication() == null) { + authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request)); + SecurityContextHolder.getContext().setAuthentication(authToken); + } } catch (Exception _) { - // Clear context if token is expired, malformed, or invalid SecurityContextHolder.clearContext(); } - // Continue the filter chain filterChain.doFilter(request, response); } -} \ No newline at end of file +} diff --git a/src/main/java/com/p2ps/auth/security/SecurityConfig.java b/src/main/java/com/p2ps/auth/security/SecurityConfig.java index 033ac6b..4e97205 100644 --- a/src/main/java/com/p2ps/auth/security/SecurityConfig.java +++ b/src/main/java/com/p2ps/auth/security/SecurityConfig.java @@ -57,6 +57,8 @@ public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Excepti ) .authorizeHttpRequests(auth -> auth .requestMatchers("/api/auth/**").permitAll() + .requestMatchers("/ws/**").permitAll() + .requestMatchers("/").permitAll() .anyRequest().authenticated() ) .sessionManagement(sess -> sess.sessionCreationPolicy(SessionCreationPolicy.STATELESS)) diff --git a/src/main/java/com/p2ps/config/JwtHandshakeInterceptor.java b/src/main/java/com/p2ps/config/JwtHandshakeInterceptor.java new file mode 100644 index 0000000..88a05eb --- /dev/null +++ b/src/main/java/com/p2ps/config/JwtHandshakeInterceptor.java @@ -0,0 +1,64 @@ +package com.p2ps.config; + +import com.p2ps.auth.security.JwtAuthFilter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.stereotype.Component; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.util.UriComponentsBuilder; + +import java.util.Map; + +@Component +public class JwtHandshakeInterceptor implements HandshakeInterceptor { + + public static final String SESSION_TOKEN_ATTRIBUTE = "wsJwtToken"; + + private static final Logger logger = LoggerFactory.getLogger(JwtHandshakeInterceptor.class); + + private final JwtAuthFilter jwtAuthFilter; + private final boolean enableUrlToken; + + public JwtHandshakeInterceptor(JwtAuthFilter jwtAuthFilter, + @Value("${websocket.compatibility.enableUrlToken:false}") boolean enableUrlToken) { + this.jwtAuthFilter = jwtAuthFilter; + this.enableUrlToken = enableUrlToken; + } + + @Override + public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, + Map attributes) { + if (!enableUrlToken) { + return true; + } + + String token = UriComponentsBuilder.fromUri(request.getURI()) + .build() + .getQueryParams() + .getFirst("token"); + + if (token == null || token.isBlank()) { + return true; + } + + if (jwtAuthFilter.authenticateToken(token) == null) { + logger.warn("Rejecting websocket handshake with invalid JWT query token"); + response.setStatusCode(HttpStatus.UNAUTHORIZED); + return false; + } + + attributes.put(SESSION_TOKEN_ATTRIBUTE, token); + return true; + } + + @Override + public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, + Exception exception) { + // No-op. + } +} diff --git a/src/main/java/com/p2ps/config/RoomSubscriptionInterceptor.java b/src/main/java/com/p2ps/config/RoomSubscriptionInterceptor.java index dec28be..e61bdbe 100644 --- a/src/main/java/com/p2ps/config/RoomSubscriptionInterceptor.java +++ b/src/main/java/com/p2ps/config/RoomSubscriptionInterceptor.java @@ -2,7 +2,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.lang.Nullable; +import org.jspecify.annotations.Nullable; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.simp.stomp.StompCommand; @@ -10,7 +10,9 @@ import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.stereotype.Component; +import org.springframework.security.core.Authentication; +import java.security.Principal; import java.util.regex.Pattern; /** @@ -41,9 +43,14 @@ public Message preSend(Message message, MessageChannel channel) { String destination = accessor.getDestination(); if (destination != null && destination.startsWith("/topic/list/")) { + Principal principal = accessor.getUser(); + if (!(principal instanceof Authentication authentication) || !authentication.isAuthenticated()) { + logger.warn("Security Alert: Blocked subscription attempt without authenticated principal"); + return null; + } + String listId = destination.substring("/topic/list/".length()); - // Security Check: Only allow alphanumeric list IDs (plus hyphens). Prevents directory traversal or wildcard injection. if (!VALID_LIST_ID.matcher(listId).matches()) { logger.warn("Security Alert: Blocked malformed room subscription attempt"); return null; diff --git a/src/main/java/com/p2ps/config/StompJwtAuthInterceptor.java b/src/main/java/com/p2ps/config/StompJwtAuthInterceptor.java new file mode 100644 index 0000000..03e7310 --- /dev/null +++ b/src/main/java/com/p2ps/config/StompJwtAuthInterceptor.java @@ -0,0 +1,99 @@ +package com.p2ps.config; + +import com.p2ps.auth.security.JwtAuthFilter; +import org.jspecify.annotations.Nullable; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.stereotype.Component; + +import java.util.Map; + +@Component +public class StompJwtAuthInterceptor implements ChannelInterceptor { + + private final JwtAuthFilter jwtAuthFilter; + + public StompJwtAuthInterceptor(JwtAuthFilter jwtAuthFilter) { + this.jwtAuthFilter = jwtAuthFilter; + } + + @Override + @Nullable + public Message preSend(Message message, MessageChannel channel) { + StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message); + + if (StompCommand.CONNECT.equals(accessor.getCommand())) { + UsernamePasswordAuthenticationToken authentication = resolveAuthentication(accessor); + + if (authentication != null) { + accessor.setUser(authentication); + return MessageBuilder.fromMessage(message) + .setHeader(SimpMessageHeaderAccessor.USER_HEADER, authentication) + .build(); + } + + return message; + } + + return message; + } + + private UsernamePasswordAuthenticationToken resolveAuthentication(StompHeaderAccessor accessor) { + String token = resolveToken(accessor); + if (token == null) { + return null; + } + + UsernamePasswordAuthenticationToken authentication = jwtAuthFilter.authenticateToken(token); + if (authentication == null) { + throw new BadCredentialsException("Invalid JWT token"); + } + return authentication; + } + + private String resolveToken(StompHeaderAccessor accessor) { + String headerToken = accessor.getFirstNativeHeader("Authorization"); + if (headerToken == null) { + headerToken = accessor.getFirstNativeHeader("authorization"); + } + if (headerToken == null) { + headerToken = accessor.getFirstNativeHeader("token"); + } + if (headerToken == null) { + headerToken = accessor.getFirstNativeHeader("access_token"); + } + + if (headerToken != null) { + return extractBearerToken(headerToken); + } + + Map sessionAttributes = accessor.getSessionAttributes(); + if (sessionAttributes == null) { + return null; + } + + Object sessionToken = sessionAttributes.get(JwtHandshakeInterceptor.SESSION_TOKEN_ATTRIBUTE); + return sessionToken instanceof String string ? string : null; + } + + private String extractBearerToken(String authorizationHeader) { + if (authorizationHeader == null || authorizationHeader.isBlank()) { + return null; + } + + String token = authorizationHeader.trim(); + if (token.regionMatches(true, 0, "Bearer ", 0, 7)) { + return token.substring(7).trim(); + } + + return token; + } + +} diff --git a/src/main/java/com/p2ps/config/WebSocketConfig.java b/src/main/java/com/p2ps/config/WebSocketConfig.java index fffbac5..199e886 100644 --- a/src/main/java/com/p2ps/config/WebSocketConfig.java +++ b/src/main/java/com/p2ps/config/WebSocketConfig.java @@ -20,6 +20,8 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { @Value("${app.cors.allowed-origins}") private String[] allowedOrigins; + private final JwtHandshakeInterceptor jwtHandshakeInterceptor; + private final StompJwtAuthInterceptor stompJwtAuthInterceptor; private final RoomSubscriptionInterceptor subscriptionInterceptor; /** @@ -27,7 +29,11 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { * @param subscriptionInterceptor the interceptor validating inbound traffic */ @Autowired - public WebSocketConfig(RoomSubscriptionInterceptor subscriptionInterceptor) { + public WebSocketConfig(JwtHandshakeInterceptor jwtHandshakeInterceptor, + StompJwtAuthInterceptor stompJwtAuthInterceptor, + RoomSubscriptionInterceptor subscriptionInterceptor) { + this.jwtHandshakeInterceptor = jwtHandshakeInterceptor; + this.stompJwtAuthInterceptor = stompJwtAuthInterceptor; this.subscriptionInterceptor = subscriptionInterceptor; } @@ -40,6 +46,7 @@ public WebSocketConfig(RoomSubscriptionInterceptor subscriptionInterceptor) { @Override public void registerStompEndpoints(StompEndpointRegistry registry) { registry.addEndpoint("/ws") + .addInterceptors(jwtHandshakeInterceptor) .setAllowedOriginPatterns(allowedOrigins) .withSockJS(); } @@ -62,6 +69,6 @@ public void configureMessageBroker(MessageBrokerRegistry config) { */ @Override public void configureClientInboundChannel(ChannelRegistration registration) { - registration.interceptors(subscriptionInterceptor); + registration.interceptors(stompJwtAuthInterceptor, subscriptionInterceptor); } -} \ No newline at end of file +} diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index a965e7d..b28eea1 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -5,6 +5,7 @@ spring.datasource.url=jdbc:postgresql://localhost:5433/${POSTGRES_DB} spring.datasource.username=${POSTGRES_USER} spring.datasource.password=${POSTGRES_PASSWORD} spring.datasource.driver-class-name=org.postgresql.Driver +spring.data.mongodb.uri=mongodb://localhost:27017/${MONGO_DB:p2p_shopping_mongo} jwt.secret=${JWT_SECRET:defaultDevSecretKeyThatIsAtLeast32BytesLong!!} spring.jpa.hibernate.ddl-auto=validate diff --git a/src/test/java/com/p2ps/auth/JwtAuthFilterTest.java b/src/test/java/com/p2ps/auth/JwtAuthFilterTest.java deleted file mode 100644 index ed14824..0000000 --- a/src/test/java/com/p2ps/auth/JwtAuthFilterTest.java +++ /dev/null @@ -1,74 +0,0 @@ -package com.p2ps.auth; - - -import com.p2ps.auth.security.JwtAuthFilter; -import com.p2ps.auth.security.JwtUtil; -import jakarta.servlet.FilterChain; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.boot.webmvc.test.autoconfigure.AutoConfigureMockMvc; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.test.context.ActiveProfiles; - -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.Mockito.*; - - -@SpringBootTest(properties = { - "jwt.secret=test-secret-key-care-trebuie-sa-fie-foar vhjcbfvifdbvishfiuhsiufhsuhfwa4yr78e2hfhdsiuncfjsdbhcsbdzhHbhcsdvsdfsffzvfvsaklmdl$%cjsdnfjnsjfnsjnfesf$^%$^fgjnenzskrgerte-lunga-32-chars", - "spring.datasource.url=jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1", - "spring.datasource.driver-class-name=org.h2.Driver", - "spring.datasource.username=sa", - "spring.datasource.password=", - "spring.jpa.database-platform=org.hibernate.dialect.H2Dialect", - "spring.jpa.hibernate.ddl-auto=create-drop" -}) -@AutoConfigureMockMvc(addFilters = false) -@ActiveProfiles("test") -@ExtendWith(MockitoExtension.class) -class JwtAuthFilterTest { - - @Mock - private JwtUtil jwtUtil; - - @Mock - private FilterChain filterChain; - - @InjectMocks - private JwtAuthFilter jwtAuthFilter; - - @Test - void doFilterInternal_WithValidTokenInHeader() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - MockHttpServletResponse response = new MockHttpServletResponse(); - - request.addHeader("Authorization", "Bearer valid-jwt"); - - when(jwtUtil.extractEmail("valid-jwt")).thenReturn("test@test.com"); - when(jwtUtil.isTokenExpired("valid-jwt")).thenReturn(false); - - jwtAuthFilter.doFilterInternal(request, response, filterChain); - - assertNotNull(SecurityContextHolder.getContext().getAuthentication()); - assertEquals("test@test.com", SecurityContextHolder.getContext().getAuthentication().getPrincipal()); - verify(filterChain).doFilter(request, response); - } - - @Test - void doFilterInternal_NoHeader_ShouldContinueChain() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - MockHttpServletResponse response = new MockHttpServletResponse(); - SecurityContextHolder.clearContext(); - - jwtAuthFilter.doFilterInternal(request, response, filterChain); - - assertNull(SecurityContextHolder.getContext().getAuthentication()); - verify(filterChain).doFilter(request, response); - } -} \ No newline at end of file diff --git a/src/test/java/com/p2ps/auth/security/JwtAuthFilterTest.java b/src/test/java/com/p2ps/auth/security/JwtAuthFilterTest.java new file mode 100644 index 0000000..8bfa4fe --- /dev/null +++ b/src/test/java/com/p2ps/auth/security/JwtAuthFilterTest.java @@ -0,0 +1,233 @@ +package com.p2ps.auth.security; + + +import jakarta.servlet.FilterChain; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.context.SecurityContextHolder; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + + +@ExtendWith(MockitoExtension.class) +class JwtAuthFilterTest { + + @Mock + private JwtUtil jwtUtil; + + @Mock + private FilterChain filterChain; + + @InjectMocks + private JwtAuthFilter jwtAuthFilter; + + @BeforeEach + void setUp() { + SecurityContextHolder.clearContext(); + } + + @AfterEach + void tearDown() { + SecurityContextHolder.clearContext(); + } + + @Nested + class ExtractBearerToken { + + @Test + void nullInput_ReturnsNull() { + assertNull(jwtAuthFilter.extractBearerToken(null)); + } + + @Test + void blankInput_ReturnsNull() { + assertNull(jwtAuthFilter.extractBearerToken("")); + assertNull(jwtAuthFilter.extractBearerToken(" ")); + } + + @Test + void bearerPrefix_ReturnsToken() { + assertEquals("xyz", jwtAuthFilter.extractBearerToken("Bearer xyz")); + } + + @Test + void lowercaseBearer_ReturnsToken() { + assertEquals("xyz", jwtAuthFilter.extractBearerToken("bearer xyz")); + } + + @Test + void uppercaseBearer_ReturnsToken() { + assertEquals("xyz", jwtAuthFilter.extractBearerToken("BEARER xyz")); + } + + @Test + void bareToken_ReturnsToken() { + assertEquals("xyz", jwtAuthFilter.extractBearerToken("xyz")); + } + + @Test + void bearerWithOnlyWhitespace_ReturnsBearer() { + assertEquals("Bearer", jwtAuthFilter.extractBearerToken("Bearer ")); + } + + @Test + void bearerWithExtraSpaces_Trimmed() { + assertEquals("xyz", jwtAuthFilter.extractBearerToken(" Bearer xyz ")); + } + } + + @Nested + class AuthenticateToken { + + @Test + void nullInput_ReturnsNull() { + assertNull(jwtAuthFilter.authenticateToken(null)); + } + + @Test + void blankInput_ReturnsNull() { + assertNull(jwtAuthFilter.authenticateToken("")); + assertNull(jwtAuthFilter.authenticateToken(" ")); + } + + @Test + void validToken_ReturnsAuthentication() { + when(jwtUtil.extractEmail("valid-token")).thenReturn("test@test.com"); + when(jwtUtil.isTokenExpired("valid-token")).thenReturn(false); + + UsernamePasswordAuthenticationToken result = jwtAuthFilter.authenticateToken("valid-token"); + + assertNotNull(result); + assertEquals("test@test.com", result.getPrincipal()); + } + + @Test + void nullEmail_ReturnsNull() { + when(jwtUtil.extractEmail("token")).thenReturn(null); + + assertNull(jwtAuthFilter.authenticateToken("token")); + verify(jwtUtil, never()).isTokenExpired(any()); + } + + @Test + void expiredToken_ReturnsNull() { + when(jwtUtil.extractEmail("expired-token")).thenReturn("test@test.com"); + when(jwtUtil.isTokenExpired("expired-token")).thenReturn(true); + + assertNull(jwtAuthFilter.authenticateToken("expired-token")); + } + + @Test + void exceptionDuringExtraction_ReturnsNull() { + when(jwtUtil.extractEmail("bad-token")).thenThrow(new RuntimeException("parse error")); + + assertNull(jwtAuthFilter.authenticateToken("bad-token")); + } + } + + @Nested + class DoFilterInternal { + + @Test + void withValidTokenInHeader() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + request.addHeader("Authorization", "Bearer valid-jwt"); + + when(jwtUtil.extractEmail("valid-jwt")).thenReturn("test@test.com"); + when(jwtUtil.isTokenExpired("valid-jwt")).thenReturn(false); + + jwtAuthFilter.doFilterInternal(request, response, filterChain); + + assertNotNull(SecurityContextHolder.getContext().getAuthentication()); + assertEquals("test@test.com", SecurityContextHolder.getContext().getAuthentication().getPrincipal()); + verify(filterChain).doFilter(request, response); + } + + @Test + void noHeader_ShouldContinueChain() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + jwtAuthFilter.doFilterInternal(request, response, filterChain); + + assertNull(SecurityContextHolder.getContext().getAuthentication()); + verify(filterChain).doFilter(request, response); + } + + @Test + void expiredToken_ContextNotSet() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + request.addHeader("Authorization", "Bearer expired-jwt"); + + when(jwtUtil.extractEmail("expired-jwt")).thenReturn("test@test.com"); + when(jwtUtil.isTokenExpired("expired-jwt")).thenReturn(true); + + jwtAuthFilter.doFilterInternal(request, response, filterChain); + + assertNull(SecurityContextHolder.getContext().getAuthentication()); + verify(filterChain).doFilter(request, response); + } + + @Test + void alreadyAuthenticated_ContextNotOverwritten() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + UsernamePasswordAuthenticationToken existing = new UsernamePasswordAuthenticationToken("existing", null); + SecurityContextHolder.getContext().setAuthentication(existing); + + request.addHeader("Authorization", "Bearer new-jwt"); + when(jwtUtil.extractEmail("new-jwt")).thenReturn("new@test.com"); + when(jwtUtil.isTokenExpired("new-jwt")).thenReturn(false); + + jwtAuthFilter.doFilterInternal(request, response, filterChain); + + assertEquals("existing", SecurityContextHolder.getContext().getAuthentication().getPrincipal()); + verify(filterChain).doFilter(request, response); + } + + @Test + void exceptionDuringAuth_ContextClearedAndContinues() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + request.addHeader("Authorization", "Bearer bad-jwt"); + when(jwtUtil.extractEmail("bad-jwt")).thenThrow(new RuntimeException("error")); + + jwtAuthFilter.doFilterInternal(request, response, filterChain); + + assertNull(SecurityContextHolder.getContext().getAuthentication()); + verify(filterChain).doFilter(request, response); + } + + @Test + void bareTokenWithoutBearerPrefix() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + request.addHeader("Authorization", "raw-token-value"); + + when(jwtUtil.extractEmail("raw-token-value")).thenReturn("test@test.com"); + when(jwtUtil.isTokenExpired("raw-token-value")).thenReturn(false); + + jwtAuthFilter.doFilterInternal(request, response, filterChain); + + assertNotNull(SecurityContextHolder.getContext().getAuthentication()); + assertEquals("test@test.com", SecurityContextHolder.getContext().getAuthentication().getPrincipal()); + } + } +} diff --git a/src/test/java/com/p2ps/auth/security/SecurityConfigTest.java b/src/test/java/com/p2ps/auth/security/SecurityConfigTest.java new file mode 100644 index 0000000..c44cbe1 --- /dev/null +++ b/src/test/java/com/p2ps/auth/security/SecurityConfigTest.java @@ -0,0 +1,75 @@ +package com.p2ps.auth.security; + +import com.p2ps.auth.service.UserService; +import org.junit.jupiter.api.Test; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.ProviderManager; +import org.springframework.security.authentication.dao.DaoAuthenticationProvider; +import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; +import org.springframework.security.crypto.password.PasswordEncoder; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.CorsConfigurationSource; +import org.springframework.web.cors.UrlBasedCorsConfigurationSource; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +class SecurityConfigTest { + + @Test + void passwordEncoder_ReturnsBCryptPasswordEncoder() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + SecurityConfig config = new SecurityConfig(jwtAuthFilter); + + PasswordEncoder encoder = config.passwordEncoder(); + + assertNotNull(encoder); + assertInstanceOf(BCryptPasswordEncoder.class, encoder); + } + + @Test + void authenticationManager_ReturnsProviderManagerWithDaoProvider() throws Exception { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + SecurityConfig config = new SecurityConfig(jwtAuthFilter); + PasswordEncoder passwordEncoder = config.passwordEncoder(); + UserService userService = mock(UserService.class); + + AuthenticationManager manager = config.authenticationManager(mock(org.springframework.security.config.annotation.web.builders.HttpSecurity.class), passwordEncoder, userService); + + assertNotNull(manager); + assertInstanceOf(ProviderManager.class, manager); + + ProviderManager providerManager = (ProviderManager) manager; + assertEquals(1, providerManager.getProviders().size()); + assertInstanceOf(DaoAuthenticationProvider.class, providerManager.getProviders().get(0)); + } + + @Test + void corsConfigurationSource_ConfiguredCorrectly() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + SecurityConfig config = new SecurityConfig(jwtAuthFilter); + + CorsConfigurationSource source = config.corsConfigurationSource(); + + assertNotNull(source); + assertInstanceOf(UrlBasedCorsConfigurationSource.class, source); + + CorsConfiguration corsConfig = ((UrlBasedCorsConfigurationSource) source).getCorsConfiguration(new MockHttpServletRequest("/api/test")); + + assertNotNull(corsConfig); + assertEquals(List.of("http://localhost:5173"), corsConfig.getAllowedOrigins()); + assertEquals(List.of("GET", "POST", "PUT", "DELETE", "OPTIONS"), corsConfig.getAllowedMethods()); + assertEquals(List.of("Authorization", "Content-Type", "Accept"), corsConfig.getAllowedHeaders()); + assertTrue(corsConfig.getAllowCredentials()); + } + + private static class MockHttpServletRequest extends org.springframework.mock.web.MockHttpServletRequest { + public MockHttpServletRequest(String pathInfo) { + setRequestURI(pathInfo); + setServletPath(pathInfo); + } + } +} diff --git a/src/test/java/com/p2ps/config/JwtHandshakeInterceptorTest.java b/src/test/java/com/p2ps/config/JwtHandshakeInterceptorTest.java new file mode 100644 index 0000000..ced4769 --- /dev/null +++ b/src/test/java/com/p2ps/config/JwtHandshakeInterceptorTest.java @@ -0,0 +1,123 @@ +package com.p2ps.config; + +import com.p2ps.auth.security.JwtAuthFilter; +import org.junit.jupiter.api.Test; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketHandler; + +import java.net.URI; +import java.util.HashMap; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +class JwtHandshakeInterceptorTest { + + @Test + void beforeHandshake_AllowsValidQueryToken() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + JwtHandshakeInterceptor interceptor = new JwtHandshakeInterceptor(jwtAuthFilter, true); + ServerHttpRequest request = mock(ServerHttpRequest.class); + ServerHttpResponse response = mock(ServerHttpResponse.class); + WebSocketHandler handler = mock(WebSocketHandler.class); + HashMap attributes = new HashMap<>(); + + when(request.getURI()).thenReturn(URI.create("https://example.com/ws?token=valid-token")); + + when(jwtAuthFilter.authenticateToken("valid-token")).thenReturn(new org.springframework.security.authentication.UsernamePasswordAuthenticationToken("user@test.com", null, java.util.List.of())); + + boolean allowed = interceptor.beforeHandshake(request, response, handler, attributes); + + assertTrue(allowed); + assertEquals("valid-token", attributes.get(JwtHandshakeInterceptor.SESSION_TOKEN_ATTRIBUTE)); + verify(response, never()).setStatusCode(any()); + } + + @Test + void beforeHandshake_RejectsInvalidQueryToken() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + JwtHandshakeInterceptor interceptor = new JwtHandshakeInterceptor(jwtAuthFilter, true); + ServerHttpRequest request = mock(ServerHttpRequest.class); + ServerHttpResponse response = mock(ServerHttpResponse.class); + WebSocketHandler handler = mock(WebSocketHandler.class); + HashMap attributes = new HashMap<>(); + + when(request.getURI()).thenReturn(URI.create("https://example.com/ws?token=bad-token")); + + when(jwtAuthFilter.authenticateToken("bad-token")).thenReturn(null); + + boolean allowed = interceptor.beforeHandshake(request, response, handler, attributes); + + assertFalse(allowed); + verify(response).setStatusCode(HttpStatus.UNAUTHORIZED); + } + + @Test + void beforeHandshake_SkipsUrlTokenWhenFlagDisabled() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + JwtHandshakeInterceptor interceptor = new JwtHandshakeInterceptor(jwtAuthFilter, false); + ServerHttpRequest request = mock(ServerHttpRequest.class); + ServerHttpResponse response = mock(ServerHttpResponse.class); + WebSocketHandler handler = mock(WebSocketHandler.class); + HashMap attributes = new HashMap<>(); + + when(request.getURI()).thenReturn(URI.create("https://example.com/ws?token=some-token")); + + boolean allowed = interceptor.beforeHandshake(request, response, handler, attributes); + + assertTrue(allowed); + assertNull(attributes.get(JwtHandshakeInterceptor.SESSION_TOKEN_ATTRIBUTE)); + verify(jwtAuthFilter, never()).authenticateToken(any()); + verify(response, never()).setStatusCode(any()); + } + + @Test + void beforeHandshake_NullTokenWithFlagEnabled_Allows() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + JwtHandshakeInterceptor interceptor = new JwtHandshakeInterceptor(jwtAuthFilter, true); + ServerHttpRequest request = mock(ServerHttpRequest.class); + ServerHttpResponse response = mock(ServerHttpResponse.class); + WebSocketHandler handler = mock(WebSocketHandler.class); + HashMap attributes = new HashMap<>(); + + when(request.getURI()).thenReturn(URI.create("https://example.com/ws")); + + boolean allowed = interceptor.beforeHandshake(request, response, handler, attributes); + + assertTrue(allowed); + assertNull(attributes.get(JwtHandshakeInterceptor.SESSION_TOKEN_ATTRIBUTE)); + verify(jwtAuthFilter, never()).authenticateToken(any()); + } + + @Test + void beforeHandshake_BlankTokenWithFlagEnabled_Allows() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + JwtHandshakeInterceptor interceptor = new JwtHandshakeInterceptor(jwtAuthFilter, true); + ServerHttpRequest request = mock(ServerHttpRequest.class); + ServerHttpResponse response = mock(ServerHttpResponse.class); + WebSocketHandler handler = mock(WebSocketHandler.class); + HashMap attributes = new HashMap<>(); + + when(request.getURI()).thenReturn(URI.create("https://example.com/ws?token=")); + + boolean allowed = interceptor.beforeHandshake(request, response, handler, attributes); + + assertTrue(allowed); + assertNull(attributes.get(JwtHandshakeInterceptor.SESSION_TOKEN_ATTRIBUTE)); + verify(jwtAuthFilter, never()).authenticateToken(any()); + } + + @Test + void afterHandshake_NoOp() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + JwtHandshakeInterceptor interceptor = new JwtHandshakeInterceptor(jwtAuthFilter, true); + ServerHttpRequest request = mock(ServerHttpRequest.class); + ServerHttpResponse response = mock(ServerHttpResponse.class); + WebSocketHandler handler = mock(WebSocketHandler.class); + + assertDoesNotThrow(() -> interceptor.afterHandshake(request, response, handler, new RuntimeException("test"))); + assertDoesNotThrow(() -> interceptor.afterHandshake(request, response, handler, null)); + } +} diff --git a/src/test/java/com/p2ps/config/RoomSubscriptionInterceptorTest.java b/src/test/java/com/p2ps/config/RoomSubscriptionInterceptorTest.java index c207c42..58c6a4f 100644 --- a/src/test/java/com/p2ps/config/RoomSubscriptionInterceptorTest.java +++ b/src/test/java/com/p2ps/config/RoomSubscriptionInterceptorTest.java @@ -9,6 +9,7 @@ import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import java.util.stream.Stream; @@ -20,16 +21,28 @@ class RoomSubscriptionInterceptorTest { private final RoomSubscriptionInterceptor interceptor = new RoomSubscriptionInterceptor(); private Message createMessage(StompCommand command, String destination) { + return createMessage(command, destination, null); + } + + private Message createMessage(StompCommand command, String destination, UsernamePasswordAuthenticationToken user) { StompHeaderAccessor accessor = StompHeaderAccessor.create(command); if (destination != null) { accessor.setDestination(destination); } + if (user != null) { + accessor.setUser(user); + } + accessor.setLeaveMutable(true); return MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); } @Test void preSend_ValidSubscription() { - Message message = createMessage(StompCommand.SUBSCRIBE, "/topic/list/valid-ID-123"); + Message message = createMessage( + StompCommand.SUBSCRIBE, + "/topic/list/valid-ID-123", + new UsernamePasswordAuthenticationToken("test@test.com", null, java.util.List.of()) + ); MessageChannel channel = mock(MessageChannel.class); Message result = interceptor.preSend(message, channel); @@ -39,7 +52,11 @@ void preSend_ValidSubscription() { @Test void preSend_InvalidSubscription_ReturnsNull() { - Message message = createMessage(StompCommand.SUBSCRIBE, "/topic/list/invalid_ID!"); + Message message = createMessage( + StompCommand.SUBSCRIBE, + "/topic/list/invalid_ID!", + new UsernamePasswordAuthenticationToken("test@test.com", null, java.util.List.of()) + ); MessageChannel channel = mock(MessageChannel.class); Message result = interceptor.preSend(message, channel); @@ -50,7 +67,9 @@ void preSend_InvalidSubscription_ReturnsNull() { @ParameterizedTest(name = "command={0}, destination={1}") @MethodSource("nonBlockingSubscriptions") void preSend_NonBlockingSubscriptions_Pass(StompCommand command, String destination) { - Message message = createMessage(command, destination); + Message message = command == StompCommand.SUBSCRIBE + ? createMessage(command, destination, new UsernamePasswordAuthenticationToken("test@test.com", null, java.util.List.of())) + : createMessage(command, destination); MessageChannel channel = mock(MessageChannel.class); Message result = interceptor.preSend(message, channel); @@ -65,5 +84,14 @@ static Stream nonBlockingSubscriptions() { Arguments.of(StompCommand.SUBSCRIBE, "/topic/other/invalid_ID!") ); } -} + @Test + void preSend_SubscribeWithoutPrincipal_ReturnsNull() { + Message message = createMessage(StompCommand.SUBSCRIBE, "/topic/list/valid-ID-123"); + MessageChannel channel = mock(MessageChannel.class); + + Message result = interceptor.preSend(message, channel); + + assertNull(result); + } +} diff --git a/src/test/java/com/p2ps/config/StompJwtAuthInterceptorTest.java b/src/test/java/com/p2ps/config/StompJwtAuthInterceptorTest.java new file mode 100644 index 0000000..139de30 --- /dev/null +++ b/src/test/java/com/p2ps/config/StompJwtAuthInterceptorTest.java @@ -0,0 +1,220 @@ +package com.p2ps.config; + +import com.p2ps.auth.security.JwtAuthFilter; +import org.junit.jupiter.api.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; + +import java.util.List; +import java.util.HashMap; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +class StompJwtAuthInterceptorTest { + + @Test + void preSend_ConnectWithHeaderTokenAuthenticates() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); + accessor.addNativeHeader("Authorization", "Bearer valid-token"); + accessor.setLeaveMutable(true); + accessor.setSessionAttributes(new HashMap<>()); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("user@test.com", null, List.of()); + + when(jwtAuthFilter.authenticateToken(anyString())).thenAnswer(invocation -> { + assertEquals("valid-token", invocation.getArgument(0)); + return authentication; + }); + + Message result = interceptor.preSend(message, channel); + + assertNotNull(result); + assertSame(authentication, StompHeaderAccessor.wrap(result).getUser()); + } + + @Test + void preSend_ConnectWithoutTokenPassesThrough() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); + accessor.setLeaveMutable(true); + accessor.setSessionAttributes(new HashMap<>()); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + + when(jwtAuthFilter.authenticateToken(null)).thenReturn(null); + + Message result = interceptor.preSend(message, channel); + + assertNotNull(result); + assertNull(StompHeaderAccessor.wrap(result).getUser()); + } + + @Test + void preSend_ConnectUsesSessionTokenWhenPresent() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); + accessor.setLeaveMutable(true); + accessor.setSessionAttributes(new HashMap<>(java.util.Map.of(JwtHandshakeInterceptor.SESSION_TOKEN_ATTRIBUTE, "query-token"))); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("user@test.com", null, List.of()); + + when(jwtAuthFilter.authenticateToken(anyString())).thenAnswer(invocation -> { + assertEquals("query-token", invocation.getArgument(0)); + return authentication; + }); + + Message result = interceptor.preSend(message, channel); + + assertNotNull(result); + assertSame(authentication, StompHeaderAccessor.wrap(result).getUser()); + } + + @Test + void preSend_NonConnectCommandsPassThrough() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + + for (StompCommand command : new StompCommand[]{StompCommand.SUBSCRIBE, StompCommand.SEND, StompCommand.DISCONNECT, StompCommand.UNSUBSCRIBE}) { + StompHeaderAccessor accessor = StompHeaderAccessor.create(command); + accessor.setLeaveMutable(true); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + + Message result = interceptor.preSend(message, channel); + + assertSame(message, result, "Expected same message for " + command); + } + } + + @Test + void preSend_ConnectWithLowercaseAuthorizationHeader() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); + accessor.addNativeHeader("authorization", "Bearer lower-token"); + accessor.setLeaveMutable(true); + accessor.setSessionAttributes(new HashMap<>()); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("user@test.com", null, List.of()); + + when(jwtAuthFilter.authenticateToken(anyString())).thenAnswer(invocation -> { + assertEquals("lower-token", invocation.getArgument(0)); + return authentication; + }); + + Message result = interceptor.preSend(message, channel); + + assertNotNull(result); + assertSame(authentication, StompHeaderAccessor.wrap(result).getUser()); + } + + @Test + void preSend_ConnectWithTokenHeader() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); + accessor.addNativeHeader("token", "token-header-value"); + accessor.setLeaveMutable(true); + accessor.setSessionAttributes(new HashMap<>()); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("user@test.com", null, List.of()); + + when(jwtAuthFilter.authenticateToken(anyString())).thenAnswer(invocation -> { + assertEquals("token-header-value", invocation.getArgument(0)); + return authentication; + }); + + Message result = interceptor.preSend(message, channel); + + assertNotNull(result); + assertSame(authentication, StompHeaderAccessor.wrap(result).getUser()); + } + + @Test + void preSend_ConnectWithAccessTokenHeader() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); + accessor.addNativeHeader("access_token", "access-token-value"); + accessor.setLeaveMutable(true); + accessor.setSessionAttributes(new HashMap<>()); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("user@test.com", null, List.of()); + + when(jwtAuthFilter.authenticateToken(anyString())).thenAnswer(invocation -> { + assertEquals("access-token-value", invocation.getArgument(0)); + return authentication; + }); + + Message result = interceptor.preSend(message, channel); + + assertNotNull(result); + assertSame(authentication, StompHeaderAccessor.wrap(result).getUser()); + } + + @Test + void preSend_ConnectWithNullSessionAttributes() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); + accessor.setLeaveMutable(true); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + + when(jwtAuthFilter.authenticateToken(null)).thenReturn(null); + + Message result = interceptor.preSend(message, channel); + + assertNotNull(result); + assertNull(StompHeaderAccessor.wrap(result).getUser()); + } + + @Test + void preSend_ConnectWithNonStringSessionToken() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); + accessor.setLeaveMutable(true); + accessor.setSessionAttributes(new HashMap<>(java.util.Map.of(JwtHandshakeInterceptor.SESSION_TOKEN_ATTRIBUTE, 12345))); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + + when(jwtAuthFilter.authenticateToken(null)).thenReturn(null); + + Message result = interceptor.preSend(message, channel); + + assertNotNull(result); + assertNull(StompHeaderAccessor.wrap(result).getUser()); + } + + @Test + void preSend_ConnectWithInvalidTokenThrowsException() { + JwtAuthFilter jwtAuthFilter = mock(JwtAuthFilter.class); + StompJwtAuthInterceptor interceptor = new StompJwtAuthInterceptor(jwtAuthFilter); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); + accessor.addNativeHeader("Authorization", "Bearer expired-token"); + accessor.setLeaveMutable(true); + accessor.setSessionAttributes(new HashMap<>()); + Message message = MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + MessageChannel channel = mock(MessageChannel.class); + + when(jwtAuthFilter.authenticateToken(anyString())).thenReturn(null); + + assertThrows(BadCredentialsException.class, () -> interceptor.preSend(message, channel)); + } +} diff --git a/src/test/java/com/p2ps/config/WebSocketConfigTest.java b/src/test/java/com/p2ps/config/WebSocketConfigTest.java index ce70d85..72367de 100644 --- a/src/test/java/com/p2ps/config/WebSocketConfigTest.java +++ b/src/test/java/com/p2ps/config/WebSocketConfigTest.java @@ -15,13 +15,17 @@ class WebSocketConfigTest { + private JwtHandshakeInterceptor handshakeInterceptor; + private StompJwtAuthInterceptor stompInterceptor; private RoomSubscriptionInterceptor interceptor; private WebSocketConfig config; @BeforeEach void setUp() { + handshakeInterceptor = mock(JwtHandshakeInterceptor.class); + stompInterceptor = mock(StompJwtAuthInterceptor.class); interceptor = mock(RoomSubscriptionInterceptor.class); - config = new WebSocketConfig(interceptor); + config = new WebSocketConfig(handshakeInterceptor, stompInterceptor, interceptor); } @Test @@ -35,11 +39,13 @@ void registerStompEndpoints() throws Exception { StompWebSocketEndpointRegistration reg = mock(StompWebSocketEndpointRegistration.class); when(registry.addEndpoint(anyString())).thenReturn(reg); + when(reg.addInterceptors(any())).thenReturn(reg); when(reg.setAllowedOriginPatterns(any())).thenReturn(reg); config.registerStompEndpoints(registry); verify(registry).addEndpoint("/ws"); + verify(reg).addInterceptors(handshakeInterceptor); verify(reg).setAllowedOriginPatterns(allowedOrigins); verify(reg).withSockJS(); } @@ -60,7 +66,6 @@ void configureClientInboundChannel() { config.configureClientInboundChannel(registration); - verify(registration).interceptors(interceptor); + verify(registration).interceptors(stompInterceptor, interceptor); } } - diff --git a/src/test/java/resources/application-test.properties b/src/test/java/resources/application-test.properties index f7c4e57..f424c28 100644 --- a/src/test/java/resources/application-test.properties +++ b/src/test/java/resources/application-test.properties @@ -11,6 +11,7 @@ spring.jpa.hibernate.ddl-auto=create-drop spring.jpa.show-sql=false spring.jpa.properties.hibernate.dialect=org.hibernate.dialect.H2Dialect -spring.data.mongodb.uri=mongodb://mock-address:27017/test +# MongoDB - can be overridden via SPRING_DATA_MONGODB_URI env var for integration tests with real MongoDB +spring.data.mongodb.uri=${SPRING_DATA_MONGODB_URI:mongodb://mock-address:27017/test} server.port=8081 \ No newline at end of file