From 2d49b63472c321849f2261815d9bdc305c36e2c1 Mon Sep 17 00:00:00 2001 From: Jialei Date: Mon, 8 Jan 2024 15:18:47 +0800 Subject: [PATCH] fix(controller): allow recovering project if project is deleted --- .../security/JwtTokenFilter.java | 94 ++++++++++++------- .../security/JwtTokenFilterTest.java | 16 +++- 2 files changed, 73 insertions(+), 37 deletions(-) diff --git a/server/controller/src/main/java/ai/starwhale/mlops/configuration/security/JwtTokenFilter.java b/server/controller/src/main/java/ai/starwhale/mlops/configuration/security/JwtTokenFilter.java index 9046f4635d..d91d09bd1e 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/configuration/security/JwtTokenFilter.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/configuration/security/JwtTokenFilter.java @@ -26,13 +26,12 @@ import ai.starwhale.mlops.domain.user.bo.Role; import ai.starwhale.mlops.domain.user.bo.User; import ai.starwhale.mlops.exception.StarwhaleException; -import ai.starwhale.mlops.exception.SwNotFoundException; -import ai.starwhale.mlops.exception.SwNotFoundException.ResourceType; import ai.starwhale.mlops.exception.SwValidationException; import io.jsonwebtoken.Claims; import java.io.IOException; import java.util.List; import java.util.Set; +import java.util.regex.Pattern; import java.util.stream.Collectors; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -58,40 +57,50 @@ public class JwtTokenFilter extends OncePerRequestFilter { private final List jwtClaimValidators; private static final String AUTH_HEADER = "Authorization"; - - public JwtTokenFilter(JwtTokenUtil jwtTokenUtil, UserService userService, ProjectService projectService, - List jwtClaimValidators) { + private static final List WHITE_LIST_FOR_DELETED_PROJECTS = List.of( + Pattern.compile("/api/v1/project/[^/]+/recover") + ); + + public JwtTokenFilter( + JwtTokenUtil jwtTokenUtil, + UserService userService, + ProjectService projectService, + List jwtClaimValidators + ) { this.jwtTokenUtil = jwtTokenUtil; this.userService = userService; this.projectService = projectService; this.jwtClaimValidators = jwtClaimValidators; } - boolean allowAnonymous(HttpServletRequest request) { - try { - var projects = getProjects(request); - // only for public project - return projects.stream().allMatch(p -> p.getPrivacy() == Project.Privacy.PUBLIC); - } catch (SwNotFoundException e) { - return false; - } + boolean allowAnonymous(Set projects) { + // only for public project + return projects.stream().allMatch(p -> p.getPrivacy() == Project.Privacy.PUBLIC); } @Override - protected void doFilterInternal(HttpServletRequest httpServletRequest, + protected void doFilterInternal( + HttpServletRequest httpServletRequest, @NonNull HttpServletResponse httpServletResponse, - @NonNull FilterChain filterChain) throws ServletException, IOException { + @NonNull FilterChain filterChain + ) throws ServletException, IOException { String header = httpServletRequest.getHeader(AUTH_HEADER); - if (!checkHeader(header)) { + if (isInvalidAuthHeader(header)) { header = httpServletRequest.getParameter(AUTH_HEADER); } - if (!checkHeader(header)) { + + var projects = getProjects(httpServletRequest); + if (!verifyProjectsExist(httpServletRequest, httpServletResponse, projects)) { + return; + } + + if (isInvalidAuthHeader(header)) { // check whether the uri allow anonymous in public project - if (allowAnonymous(httpServletRequest)) { + if (allowAnonymous(projects)) { // Build jwt token with anonymous user JwtLoginToken jwtLoginToken = new JwtLoginToken(null, "", List.of( - Role.builder().roleCode(Role.CODE_ANONYMOUS).roleName(Role.NAME_ANONYMOUS).build())); + Role.builder().roleCode(Role.CODE_ANONYMOUS).roleName(Role.NAME_ANONYMOUS).build())); jwtLoginToken.setDetails(new WebAuthenticationDetails(httpServletRequest)); SecurityContextHolder.getContext().setAuthentication(jwtLoginToken); } else { @@ -123,12 +132,8 @@ protected void doFilterInternal(HttpServletRequest httpServletRequest, role -> role.getAuthority().equals(Role.CODE_OWNER)).collect(Collectors.toSet()); // Get project roles try { - Set projects = getProjects(httpServletRequest); Set rolesOfUser = userService.getProjectsRolesOfUser(user, projects); roles.addAll(rolesOfUser); - } catch (SwNotFoundException e) { - error(httpServletResponse, HttpStatus.NOT_FOUND.value(), Code.validationException, e.getMessage()); - return; } catch (StarwhaleException e) { logger.error(e.getMessage()); } @@ -142,23 +147,40 @@ protected void doFilterInternal(HttpServletRequest httpServletRequest, } @NotNull - private Set getProjects(HttpServletRequest httpServletRequest) throws SwNotFoundException { + private Set getProjects(HttpServletRequest httpServletRequest) { @SuppressWarnings("unchecked") - Set projects = ((Set) httpServletRequest - .getAttribute(ProjectDetectionFilter.ATTRIBUTE_PROJECT)) + var projectIds = (Set) httpServletRequest.getAttribute(ProjectDetectionFilter.ATTRIBUTE_PROJECT); + if (projectIds == null) { + return Set.of(); + } + + return projectIds .stream() - .map((String projectUrl) -> { - var p = projectService.findProject(projectUrl); - if (p.isDeleted()) { - throw new SwNotFoundException(ResourceType.PROJECT, "Project is deleted"); - } - return p; - }) + .map(projectService::findProject) .collect(Collectors.toSet()); - return projects; } - private boolean checkHeader(String header) { - return StringUtils.hasText(header) && header.startsWith("Bearer "); + private boolean isInvalidAuthHeader(String header) { + return !StringUtils.hasText(header) || !header.startsWith("Bearer "); + } + + private boolean verifyProjectsExist(HttpServletRequest request, HttpServletResponse response, Set projects) + throws IOException { + // never check for root path + var uri = request.getRequestURI(); + if (!StringUtils.hasText(uri)) { + return true; + } + if (projects.isEmpty()) { + return true; + } + if (projects.stream().noneMatch(Project::isDeleted)) { + return true; + } + if (WHITE_LIST_FOR_DELETED_PROJECTS.stream().anyMatch(p -> p.matcher(request.getRequestURI()).matches())) { + return true; + } + error(response, HttpStatus.NOT_FOUND.value(), Code.validationException, "Project is deleted"); + return false; } } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/configuration/security/JwtTokenFilterTest.java b/server/controller/src/test/java/ai/starwhale/mlops/configuration/security/JwtTokenFilterTest.java index 1e3368c60c..3b7ab0ed4e 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/configuration/security/JwtTokenFilterTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/configuration/security/JwtTokenFilterTest.java @@ -17,9 +17,12 @@ package ai.starwhale.mlops.configuration.security; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; @@ -129,6 +132,7 @@ public void testDeletedProject() throws ServletException, IOException { HttpServletRequest request = mock(HttpServletRequest.class); when(request.getHeader("Authorization")).thenReturn("Bearer a"); when(request.getAttribute("PROJECT")).thenReturn(Set.of("deleted")); + when(request.getRequestURI()).thenReturn("/api/v1/project/1/jobs"); HttpServletResponse response = mock(HttpServletResponse.class); FilterChain filterchain = mock(FilterChain.class); when(jwtTokenUtil.getUsername(any())).thenReturn("foo"); @@ -137,6 +141,16 @@ public void testDeletedProject() throws ServletException, IOException { jwtTokenFilter.doFilterInternal(request, response, filterchain); httpUtilMockedStatic.verify( () -> HttpUtil.error(response, HttpStatus.NOT_FOUND.value(), Code.validationException, - "Resource is not found Project\nProject is deleted"), times(1)); + "Project is deleted"), times(1)); + + // test project recover + for (var uri : List.of("/api/v1/project/1/recover", "/api/v1/project/abc/recover")) { + when(request.getRequestURI()).thenReturn(uri); + httpUtilMockedStatic.clearInvocations(); + jwtTokenFilter.doFilterInternal(request, response, filterchain); + httpUtilMockedStatic.verify( + () -> HttpUtil.error(any(HttpServletResponse.class), anyInt(), any(Code.class), anyString()), + never()); + } } }