Skip to content

Commit

Permalink
Merge branch 'master' into feature/add_oceanbase
Browse files Browse the repository at this point in the history
  • Loading branch information
Aias00 authored Aug 2, 2024
2 parents d326fd5 + f6e8e64 commit ba51be7
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hertzbeat.alert.controller;

import org.apache.hertzbeat.alert.service.AlertSilenceService;
import org.apache.hertzbeat.common.constants.CommonConstants;
import org.apache.hertzbeat.common.entity.alerter.AlertSilence;
import org.apache.hertzbeat.common.util.JsonUtil;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.http.MediaType;
import org.springframework.test.web.servlet.MockMvc;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup;

/**
* tes case for {@link AlertSilenceController}
*/

@ExtendWith(MockitoExtension.class)
class AlertSilenceControllerTest {

private MockMvc mockMvc;

@Mock
private AlertSilenceService alertSilenceService;

private AlertSilence alertSilence;

@InjectMocks
private AlertSilenceController alertSilenceController;

@BeforeEach
void setUp() {

this.mockMvc = standaloneSetup(alertSilenceController).build();

alertSilence = AlertSilence.builder()
.id(1L)
.name("Test Silence")
.type((byte) 1)
.build();
}

@Test
void testAddNewAlertSilence() throws Exception {

doNothing().when(alertSilenceService).validate(any(AlertSilence.class), eq(false));
doNothing().when(alertSilenceService).addAlertSilence(any(AlertSilence.class));

mockMvc.perform(post("/api/alert/silence")
.contentType(MediaType.APPLICATION_JSON)
.content(JsonUtil.toJson(alertSilence)))
.andExpect(status().isOk())
.andExpect(jsonPath("$.code").value((int) CommonConstants.SUCCESS_CODE));
}

@Test
void testModifyAlertSilence() throws Exception {

doNothing().when(alertSilenceService).validate(any(AlertSilence.class), eq(true));
doNothing().when(alertSilenceService).modifyAlertSilence(any(AlertSilence.class));

mockMvc.perform(put("/api/alert/silence")
.contentType(MediaType.APPLICATION_JSON)
.content(JsonUtil.toJson(alertSilence)))
.andExpect(status().isOk())
.andExpect(jsonPath("$.code").value((int) CommonConstants.SUCCESS_CODE));
}

@Test
void testGetAlertSilence() throws Exception {

when(alertSilenceService.getAlertSilence(1L)).thenReturn(alertSilence);

mockMvc.perform(get("/api/alert/silence/1")
.accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.data.id").value(1))
.andExpect(jsonPath("$.data.name").value("Test Silence"));
}

@Test
void testGetAlertSilenceNotExists() throws Exception {

when(alertSilenceService.getAlertSilence(1L)).thenReturn(null);

mockMvc.perform(get("/api/alert/silence/1")
.accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.code").value((int) CommonConstants.MONITOR_NOT_EXIST_CODE))
.andExpect(jsonPath("$.msg").value("AlertSilence not exist."));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hertzbeat.collector.collect.common.ssh;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

/**
* Command blacklist
*/
public class CommonSshBlacklist {

private static final Set<String> BLACKLIST;

static {
Set<String> tempSet = new HashSet<>();
initializeDefaultBlacklist(tempSet);
BLACKLIST = Collections.unmodifiableSet(tempSet);
}

private CommonSshBlacklist() {
// Prevent instantiation
}

private static void initializeDefaultBlacklist(Set<String> blacklist) {
// Adding default dangerous commands to blacklist
blacklist.add("rm ");
blacklist.add("mv ");
blacklist.add("cp ");
blacklist.add("ln ");
blacklist.add("dd ");
blacklist.add("tar ");
blacklist.add("zip ");
blacklist.add("bzip2 ");
blacklist.add("bunzip2 ");
blacklist.add("xz ");
blacklist.add("unxz ");
blacklist.add("kill ");
blacklist.add("killall ");
blacklist.add("reboot");
blacklist.add("shutdown");
blacklist.add("poweroff");
blacklist.add("init 0");
blacklist.add("init 6");
blacklist.add("telinit 0");
blacklist.add("telinit 6");
blacklist.add("systemctl halt");
blacklist.add("systemctl suspend");
blacklist.add("systemctl hibernate");
blacklist.add("service reboot");
blacklist.add("service shutdown");
blacklist.add("crontab -e");
blacklist.add("visudo");
blacklist.add("useradd");
blacklist.add("userdel");
blacklist.add("usermod");
blacklist.add("groupadd");
blacklist.add("groupdel");
blacklist.add("groupmod");
blacklist.add("passwd");
blacklist.add("su ");
blacklist.add("sudo ");
blacklist.add("mount ");
blacklist.add("parted");
blacklist.add("mkpart");
blacklist.add("partprobe");
blacklist.add("iptables");
blacklist.add("firewalld");
blacklist.add("nft");
blacklist.add("nc ");
blacklist.add("netcat");
blacklist.add("ssh ");
blacklist.add("scp ");
blacklist.add("rsync");
blacklist.add("ftp ");
blacklist.add("sftp ");
blacklist.add("telnet ");
blacklist.add("chmod ");
blacklist.add("chattr ");
blacklist.add("dd ");
blacklist.add("mknod");
blacklist.add("losetup");
blacklist.add("cryptsetup");
}

public static boolean isCommandBlacklisted(String command) {
if (command == null || command.trim().isEmpty()) {
throw new IllegalArgumentException("Command cannot be null or empty");
}
String trimmedCommand = command.trim();
return BLACKLIST.stream().anyMatch(trimmedCommand::contains);
}

public static Set<String> getBlacklist() {
return BLACKLIST;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.hertzbeat.collector.collect.common.cache.CacheIdentifier;
import org.apache.hertzbeat.collector.collect.common.cache.ConnectionCommonCache;
import org.apache.hertzbeat.collector.collect.common.cache.SshConnect;
import org.apache.hertzbeat.collector.collect.common.ssh.CommonSshBlacklist;
import org.apache.hertzbeat.collector.collect.common.ssh.CommonSshClient;
import org.apache.hertzbeat.collector.dispatch.DispatchConstants;
import org.apache.hertzbeat.collector.util.CollectUtil;
Expand Down Expand Up @@ -85,6 +86,7 @@ public void preCheck(Metrics metrics) throws IllegalArgumentException {

@Override
public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {

long startTime = System.currentTimeMillis();
SshProtocol sshProtocol = metrics.getSsh();
boolean reuseConnection = Boolean.parseBoolean(sshProtocol.getReuseConnection());
Expand All @@ -93,6 +95,12 @@ public void collect(CollectRep.MetricsData.Builder builder, long monitorId, Stri
ClientSession clientSession = null;
try {
clientSession = getConnectSession(sshProtocol, timeout, reuseConnection);
if (CommonSshBlacklist.isCommandBlacklisted(sshProtocol.getScript())) {
builder.setCode(CollectRep.Code.FAIL);
builder.setMsg("The command is blacklisted: " + sshProtocol.getScript());
log.warn("The command is blacklisted: {}", sshProtocol.getScript());
return;
}
channel = clientSession.createExecChannel(sshProtocol.getScript());
ByteArrayOutputStream response = new ByteArrayOutputStream();
channel.setOut(response);
Expand Down

0 comments on commit ba51be7

Please sign in to comment.