Skip to content

Commit

Permalink
Merge pull request #26 from SafeNet-2024/feature/add-security
Browse files Browse the repository at this point in the history
[feat] ์›น์†Œ์ผ“ ์—ฐ๊ฒฐ ๋ฐ ๋ฉ”์‹œ์ง€ ์ „์†ก์‹œ JWT ํ† ํฐ ๊ฒ€์ฆ ๊ณผ์ • ์ถ”๊ฐ€
  • Loading branch information
Yeon-chae authored Jun 11, 2024
2 parents a31f1fb + 49e74a5 commit 0d3f2d9
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.SafeNet.Backend.domain.message.dto.MessageDto;
import com.SafeNet.Backend.domain.message.service.MessageService;
import com.SafeNet.Backend.domain.messageroom.service.MessageRoomService;
import com.SafeNet.Backend.global.auth.JwtTokenProvider;
import com.SafeNet.Backend.global.pubsub.RedisPublisher;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
Expand All @@ -13,58 +14,42 @@
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; // ๋กœ๊น…์„ ์œ„ํ•ด ์ถ”๊ฐ€
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestHeader;

import java.time.LocalDateTime;
import java.util.List;

@Slf4j
@RequiredArgsConstructor
@Controller
@Tag(name = "Message", description = "Message API")
public class MessageController {
@Autowired
private JwtTokenProvider jwtTokenProvider;

private final RedisPublisher redisPublisher;
private final MessageRoomService messageRoomService;
private final MessageService messageService;

// websocket "/pub/chat/message"๋กœ ๋“ค์–ด์˜ค๋Š” ๋ฉ”์‹œ์ง•์„ ์ฒ˜๋ฆฌํ•œ๋‹ค.

// websocket "/pub/chat/message"๋กœ ๋“ค์–ด์˜ค๋Š” ๋ฉ”์‹œ์ง•์„ ์ฒ˜๋ฆฌํ•œ๋‹ค.
@MessageMapping("/chat/message")
@Operation(summary = "๋ฉ”์‹œ์ง€ ๋ฐœ์†ก", description = "WebSocket์„ ํ†ตํ•ด ๋ฉ”์‹œ์ง€๋ฅผ ๋ฐœ์†กํ•œ๋‹ค")
public void message(@RequestHeader(name = "ACCESS_TOKEN", required = false) String accessToken,
@RequestHeader(name = "REFRESH_TOKEN", required = false) String refreshToken,
MessageDto messageDto) {

// ๋กœ๊น… ์ถ”๊ฐ€
log.info("Received message:");
log.info(" - sender: {}", messageDto.getSender());
log.info(" - roomId: {}", messageDto.getRoomId());
log.info(" - message: {}", messageDto.getMessage());
log.info(" - sentTime: {}", messageDto.getSentTime());
log.info("ACCESS_TOKEN: {}", accessToken);
log.info("REFRESH_TOKEN: {}", refreshToken);

// ํด๋ผ์ด์–ธํŠธ ์ฑ„ํŒ…๋ฐฉ(topic) ์ž…์žฅ, ๋Œ€ํ™”๋ฅผ ์œ„ํ•ด ๋ฆฌ์Šค๋„ˆ์™€ ์—ฐ๋™
messageRoomService.enterMessageRoom(messageDto.getRoomId());

// MessageDto ๊ฐ์ฒด๋ฅผ ๋นŒ๋”๋ฅผ ํ†ตํ•ด ์ƒ์„ฑํ•˜๊ณ  ํ˜„์žฌ ์‹œ๊ฐ„์„ ์„ค์ •
MessageDto messageWithTime = MessageDto.builder()
.sender(messageDto.getSender())
.roomId(messageDto.getRoomId())
.message(messageDto.getMessage())
.sentTime(LocalDateTime.now().toString())
.build();

// Websocket์— ๋ฐœํ–‰๋œ ๋ฉ”์‹œ์ง€๋ฅผ redis๋กœ ๋ฐœํ–‰ํ•œ๋‹ค(publish)
// ํ•ด๋‹น ์ชฝ์ง€๋ฐฉ์„ ๊ตฌ๋…(subscribe)ํ•œ ํด๋ผ์ด์–ธํŠธ์—๊ฒŒ ๋ฉ”์‹œ์ง€๊ฐ€ ์‹ค์‹œ๊ฐ„ ์ „์†ก
redisPublisher.publish(messageRoomService.getTopic(messageDto.getRoomId()), messageWithTime);
// Access Token ๊ฒ€์ฆ
if (accessToken == null || !jwtTokenProvider.validateToken(accessToken)) { // ๋ฉ”์‹œ์ง€ ์ „์†ก ์ „ ์œ ํšจํ•œ ํ† ํฐ์ธ์ง€ ๊ฒ€์ฆ
throw new AccessDeniedException("Invalid or expired token");
}

// DB์™€ Redis์— ๋ฉ”์‹œ์ง€ ์ €์žฅ
messageService.saveMessage(messageDto);
// ๋ฉ”์‹œ์ง€ ์ „์†ก ๋กœ์ง ํ˜ธ์ถœ
messageRoomService.handleMessage(messageDto.getRoomId(), messageDto.getSender(), messageDto);
}

// ๋Œ€ํ™” ๋‚ด์—ญ ์กฐํšŒ
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ public MessageService(
public void saveMessage(MessageDto messageDto) {
MessageRoom messageRoom = messageRoomRepository.findByRoomId(messageDto.getRoomId())
.orElseThrow(() -> new IllegalArgumentException("ํ•ด๋‹น ์ชฝ์ง€๋ฐฉ์ด ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค."));

if (!messageRoom.isFirstMessageSent()) {
messageRoom.setFirstMessageSent(true);
messageRoomRepository.save(messageRoom); // ์ฒซ ๋ฉ”์‹œ์ง€ ์—ฌ๋ถ€๋ฅผ true๋กœ ์—…๋ฐ์ดํŠธ
}

Message message = Message.builder()
.sender(messageDto.getSender())
.messageRoom(messageRoom)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ public class MessageRoom {
@Column(nullable = false)
private String receiver; // ์ฑ„ํŒ…๋ฐฉ ์ˆ˜์‹ ์ž

@Column(nullable = false)
private boolean firstMessageSent = false; // ๋ณด๋‚ธ ๋ฉ”์‹œ์ง€๊ฐ€ ์กด์žฌํ•˜๋Š”์ง€

@CreationTimestamp
@Column(name = "created_at", nullable = false, updatable = false)
private LocalDateTime createdAt; // ํ˜„์žฌ ์‹œ๊ฐ„ ์ž๋™ ํ• ๋‹น
Expand All @@ -50,4 +53,8 @@ public class MessageRoom {
@ManyToOne(fetch = LAZY)
@JoinColumn(name = "post_id", nullable = false)
private Post post;

public void setFirstMessageSent(boolean firstMessageSent) {
this.firstMessageSent = firstMessageSent;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

import com.SafeNet.Backend.domain.member.entity.Member;
import com.SafeNet.Backend.domain.member.repository.MemberRepository;
import com.SafeNet.Backend.domain.message.dto.MessageDto;
import com.SafeNet.Backend.domain.message.entity.Message;
import com.SafeNet.Backend.domain.message.dto.MessageResponseDto;
import com.SafeNet.Backend.domain.message.repository.MessageRepository;
import com.SafeNet.Backend.domain.message.service.MessageService;
import com.SafeNet.Backend.domain.messageroom.entity.MessageRoom;
import com.SafeNet.Backend.domain.messageroom.dto.MessageRoomDto;
import com.SafeNet.Backend.domain.messageroom.repository.MessageRoomRepository;
import com.SafeNet.Backend.domain.post.entity.Post;
import com.SafeNet.Backend.domain.post.exception.PostException;
import com.SafeNet.Backend.domain.post.repository.PostRepository;
import com.SafeNet.Backend.global.exception.CustomException;
import com.SafeNet.Backend.global.pubsub.RedisPublisher;
import com.SafeNet.Backend.global.pubsub.RedisSubscriber;
import jakarta.annotation.PostConstruct;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -23,6 +26,7 @@
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;

import java.time.LocalDateTime;
import java.util.*;

@Slf4j
Expand All @@ -38,6 +42,9 @@ public class MessageRoomService {

private HashOperations<String, String, MessageRoomDto> opsHashMessageRoom;
private Map<String, ChannelTopic> topics;
private final RedisPublisher redisPublisher; // RedisPublisher ์ฃผ์ž…

private final MessageService messageService;

private static final String Message_Rooms = "MESSAGE_ROOM"; // Redis์— ์ฑ„ํŒ…๋ฐฉ ๋ฐ์ดํ„ฐ๋ฅผ ์ €์žฅํ•˜๊ธฐ ์œ„ํ•œ ํ•ด์‹œ๋งต์˜ ํ‚ค

Expand All @@ -48,14 +55,18 @@ public MessageRoomService(
MemberRepository memberRepository,
RedisMessageListenerContainer redisMessageListener,
RedisSubscriber redisSubscriber,
@Qualifier("customRedisTemplate") RedisTemplate<String, Object> redisTemplate) {
@Qualifier("customRedisTemplate") RedisTemplate<String, Object> redisTemplate,
RedisPublisher redisPublisher,
MessageService messageService) {
this.messageRoomRepository = messageRoomRepository;
this.messageRepository = messageRepository;
this.postRepository = postRepository;
this.memberRepository = memberRepository;
this.redisMessageListener = redisMessageListener;
this.redisSubscriber = redisSubscriber;
this.redisTemplate = redisTemplate;
this.redisPublisher = redisPublisher;
this.messageService = messageService;
}

@PostConstruct
Expand Down Expand Up @@ -127,15 +138,14 @@ public List<MessageResponseDto> findAllRoomByUser(String email) {
List<MessageResponseDto> messageRoomDtos = new ArrayList<>();

for (MessageRoom messageRoom : messageRooms) {
// member๊ฐ€ sender์ธ ๊ฒฝ์šฐ
if (member.getName().equals(messageRoom.getSender())) {
if (member.getName().equals(messageRoom.getSender())) { // member๊ฐ€ sender์ธ ๊ฒฝ์šฐ
// ๊ฐ€์žฅ ์ตœ์‹  ๋ฉ”์‹œ์ง€ & ์ƒ์„ฑ ์‹œ๊ฐ„ ์กฐํšŒ
// TimeStamped ํด๋ž˜์Šค์—์„œ ์„ค์ •ํ•ด๋‘” ๋ณด๋‚ธ ์‹œ๊ฐ„(sentTime)์„ ํ†ตํ•ด ๊ฐ ์ฑ„ํŒ…๋ฐฉ์—์„œ ๊ฐ€์žฅ ์ตœ๊ทผ ๋ฉ”์‹œ์ง€์™€ ๊ทธ ๋ฉ”์‹œ์ง€๊ฐ€ ๋ณด๋‚ด์ง„ ์‹œ๊ฐ„์„ ๊บผ๋‚ธ๋‹ค.
Message latestMessage = messageRepository.findTopByMessageRoom_RoomIdOrderBySentTimeDesc(messageRoom.getRoomId());
MessageResponseDto messageRoomDto = setLatestMessage(messageRoom, latestMessage, messageRoom.getReceiver());
messageRoomDtos.add(messageRoomDto);

} else { // user๊ฐ€ receiver์ธ ๊ฒฝ์šฐ
} else if (member.getName().equals(messageRoom.getReceiver()) || messageRoom.isFirstMessageSent()) { // user๊ฐ€ receiver์ด๋ฉด์„œ ์ฒซ๋ฒˆ์งธ ๋ฉ”์‹œ์ง€๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ
// ๊ฐ€์žฅ ์ตœ์‹  ๋ฉ”์‹œ์ง€ & ์ƒ์„ฑ ์‹œ๊ฐ„ ์กฐํšŒ
Message latestMessage = messageRepository.findTopByMessageRoom_RoomIdOrderBySentTimeDesc(messageRoom.getRoomId());
MessageResponseDto messageRoomDto = setLatestMessage(messageRoom, latestMessage, messageRoom.getSender());
Expand Down Expand Up @@ -267,18 +277,47 @@ private MessageRoomDto convertToDto(MessageRoom messageRoom) {
}

// ์ชฝ์ง€๋ฐฉ ์ž…์žฅ
public void enterMessageRoom(String roomId) {
public void enterMessageRoom(String roomId, String username) {
ChannelTopic topic = topics.get(roomId);

if (topic == null) {
topic = new ChannelTopic(roomId);
redisMessageListener.addMessageListener(redisSubscriber, topic);
topics.put(roomId, topic);

// ์ž…์žฅ ๋ฉ”์‹œ์ง€ ์ถ”๊ฐ€ (์ฑ„ํŒ…๋ฐฉ์— ์ฒ˜์Œ ์ž…์žฅํ•˜๋Š” ๊ฒฝ์šฐ)
MessageDto enterMessage = MessageDto.builder()
.sender(username) // ๋ฉ”์‹œ์ง€ ๋ณด๋‚ธ ์‚ฌ๋žŒ์œผ๋กœ username ์„ค์ •
.roomId(roomId)
.message(username + "์ด(๊ฐ€) ์ฑ„ํŒ…๋ฐฉ์— ์ž…์žฅํ–ˆ์Šต๋‹ˆ๋‹ค")
.sentTime(LocalDateTime.now().toString())
.build();
redisPublisher.publish(topic, enterMessage);
}
}

// redis ์ฑ„๋„์—์„œ ์ชฝ์ง€๋ฐฉ ์กฐํšŒ
public ChannelTopic getTopic(String roomId) {
return topics.get(roomId);
}

public void handleMessage(String roomId, String username, MessageDto messageDto) {
// ํด๋ผ์ด์–ธํŠธ ์ฑ„ํŒ…๋ฐฉ(topic) ์ž…์žฅ, ๋Œ€ํ™”๋ฅผ ์œ„ํ•ด ๋ฆฌ์Šค๋„ˆ์™€ ์—ฐ๋™
enterMessageRoom(roomId, username);

// ๋ฉ”์‹œ์ง€ ์ „์†ก ๋กœ์ง
MessageDto messageWithTime = MessageDto.builder()
.sender(messageDto.getSender())
.roomId(messageDto.getRoomId())
.message(messageDto.getMessage())
.sentTime(LocalDateTime.now().toString())
.build();

// Websocket์— ๋ฐœํ–‰๋œ ๋ฉ”์‹œ์ง€๋ฅผ redis๋กœ ๋ฐœํ–‰ํ•œ๋‹ค(publish)
// ํ•ด๋‹น ์ชฝ์ง€๋ฐฉ์„ ๊ตฌ๋…(subscribe)ํ•œ ํด๋ผ์ด์–ธํŠธ์—๊ฒŒ ๋ฉ”์‹œ์ง€๊ฐ€ ์‹ค์‹œ๊ฐ„ ์ „์†ก
redisPublisher.publish(getTopic(messageDto.getRoomId()), messageWithTime);

// DB์™€ Redis์— ๋ฉ”์‹œ์ง€ ์ €์žฅ
messageService.saveMessage(messageDto);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.SafeNet.Backend.global.config;

import com.SafeNet.Backend.global.auth.JwtTokenProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.stereotype.Component;

import java.util.Collections;

// WebSocket ๋ฉ”์‹œ์ง€์˜ ํ—ค๋”์—์„œ ACCESS_TOKEN์„ ์ถ”์ถœํ•˜๊ณ  ๊ฒ€์ฆ
// ์œ ํšจํ•œ ํ† ํฐ์ด ์žˆ๋Š” ๊ฒฝ์šฐ ์‚ฌ์šฉ์ž ์ธ์ฆ ์ •๋ณด๋ฅผ ์„ค์ •ํ•˜๊ณ , ์œ ํšจํ•˜์ง€ ์•Š์€ ๊ฒฝ์šฐ ์—ฐ๊ฒฐ์„ ์ฐจ๋‹จ
@Component
public class AuthChannelInterceptor implements ChannelInterceptor {

private final JwtTokenProvider jwtTokenProvider;

@Autowired
public AuthChannelInterceptor(JwtTokenProvider jwtTokenProvider) {
this.jwtTokenProvider = jwtTokenProvider;
}

@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);
if (StompCommand.CONNECT.equals(accessor.getCommand())) { // CONNECT ํ”„๋ ˆ์ž„์€ ์„œ๋ฒ„์— ๋Œ€ํ•œ ์ธ์ฆ ๋ฐ ๊ธฐํƒ€ ์„ค์ •๊ณผ ๊ด€๋ จ๋œ ์ •๋ณด๋ฅผ ์ „์†กํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ
String token = accessor.getFirstNativeHeader("ACCESS_TOKEN");
if (token != null && token.startsWith("Bearer ")) {
token = token.substring(7);
if (jwtTokenProvider.validateToken(token)) {
String username = jwtTokenProvider.getAuthentication(token).getName();
accessor.setUser(new UsernamePasswordAuthenticationToken(username, null, Collections.emptyList()));
} else {
throw new IllegalArgumentException("Invalid or expired token");
}
} else {
throw new IllegalArgumentException("Missing or invalid ACCESS_TOKEN header");
}
}
return message;
}
}
11 changes: 11 additions & 0 deletions src/main/java/com/SafeNet/Backend/global/config/WebSockConfig.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.SafeNet.Backend.global.config;

import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
Expand All @@ -9,6 +10,16 @@
@Configuration
@EnableWebSocketMessageBroker
public class WebSockConfig implements WebSocketMessageBrokerConfigurer {
private final AuthChannelInterceptor authChannelInterceptor;

public WebSockConfig(AuthChannelInterceptor authChannelInterceptor) {
this.authChannelInterceptor = authChannelInterceptor;
}

@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
registration.interceptors(authChannelInterceptor); // ์›น์†Œ์ผ“ ๊ฒ€์ฆ์„ ์œ„ํ•œ ์ธํ„ฐ์…‰ํ„ฐ ๋“ฑ๋ก
}

@Override
public void configureMessageBroker(MessageBrokerRegistry config) {
Expand Down

0 comments on commit 0d3f2d9

Please sign in to comment.