Skip to content

Commit

Permalink
Add: check for required_ports before starting a script
Browse files Browse the repository at this point in the history
  • Loading branch information
nichtsfrei committed Jul 16, 2024
1 parent f4f8477 commit 661d652
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 13 deletions.
11 changes: 10 additions & 1 deletion rust/models/src/port.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl Display for PortRange {
}

/// Enum representing the protocol used for scanning a port.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(
feature = "serde_support",
derive(serde::Serialize, serde::Deserialize)
Expand All @@ -75,6 +75,15 @@ impl TryFrom<&str> for Protocol {
}
}

impl Display for Protocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Protocol::UDP => write!(f, "udp"),
Protocol::TCP => write!(f, "tcp"),
}
}
}

pub fn ports_to_openvas_port_list(ports: Vec<Port>) -> Option<String> {
fn add_range_to_list(list: &mut String, start: usize, end: Option<usize>) {
// Add range
Expand Down
168 changes: 156 additions & 12 deletions rust/nasl-interpreter/src/scan_interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ pub enum ScriptResultKind {
/// Contains the code provided by exit call or 0 when script finished successful without exit
/// call
ReturnCode(i64),
/// Is missing a port
MissingPort(models::Protocol, String),
/// Script did not run because an excluded key is set
ContainsExcludedKey(String),
/// Script did not run because of missing required keys
Expand Down Expand Up @@ -91,9 +93,13 @@ impl ScriptResult {
ScriptResultKind::MissingRequiredKey(_)
| ScriptResultKind::MissingMandatoryKey(_)
| ScriptResultKind::ContainsExcludedKey(_)
| ScriptResultKind::MissingPort(..)
)
}
}
pub(crate) fn generate_port_kb_key(protocol: models::Protocol, port: &str) -> String {
format!("Ports/{protocol}/{port}")
}

struct ScriptExecutor<'a, T> {
schedule: T,
Expand Down Expand Up @@ -160,7 +166,7 @@ where
fn check_key<A, B, C>(
&self,
key: &storage::ContextKey,
k: &str,
kb_key: &str,
result_none: A,
result_some: B,
result_err: C,
Expand All @@ -170,30 +176,31 @@ where
B: Fn(Primitive) -> Option<ScriptResultKind>,
C: Fn(storage::StorageError) -> Option<ScriptResultKind>,
{
let _span = tracing::error_span!("kb_item", %key, kb_key).entered();
let result = match self
.storage
.retrieve(key, storage::Retrieve::KB(k.to_string()))
.retrieve(key, storage::Retrieve::KB(kb_key.to_string()))
{
Ok(mut x) => {
let x = x.next();
if let Some(x) = x {
match x {
storage::Field::KB(kb) => {
tracing::trace!(key = k, value=?kb.value, "kb found");
tracing::trace!(value=?kb.value, "found");
result_some(kb.value)
}
x => {
tracing::trace!(key = k, field=?x, "found key but it is not a KB item");
tracing::trace!(field=?x, "found but it is not a KB item");
result_none()
}
}
} else {
tracing::trace!(key = k, "kb not found");
tracing::trace!("not found");
result_none()
}
}
Err(e) => {
tracing::warn!(error=%e, key=k, "unable to retrive kb");
tracing::warn!(error=%e, "storage error");
result_err(e)
}
};
Expand All @@ -214,6 +221,10 @@ where
|_| Some(ScriptResultKind::MissingRequiredKey(k.into())),
)
};
for k in &vt.required_keys {
check_required_key(k)?
}

let check_mandatory_key = |k: &str| {
self.check_key(
&key,
Expand All @@ -223,6 +234,10 @@ where
|_| Some(ScriptResultKind::MissingMandatoryKey(k.into())),
)
};
for k in &vt.mandatory_keys {
check_mandatory_key(k)?
}

let check_exclude_key = |k: &str| {
self.check_key(
&key,
Expand All @@ -232,18 +247,38 @@ where
|_| None,
)
};
for k in &vt.required_keys {
check_required_key(k)?
}
for k in &vt.mandatory_keys {
check_mandatory_key(k)?
}
for k in &vt.excluded_keys {
check_exclude_key(k)?
}

use models::Protocol;
let check_port = |pt: Protocol, port: &str| {
let kbk = generate_port_kb_key(pt, port);
self.check_key(
&key,
&kbk,
|| Some(ScriptResultKind::MissingPort(pt, port.to_string())),
|v| {
if v.into() {
None
} else {
Some(ScriptResultKind::MissingPort(pt, port.to_string()))
}
},
|_| Some(ScriptResultKind::MissingPort(pt, port.to_string())),
)
};
for k in &vt.required_ports {
check_port(Protocol::TCP, k)?
}
for k in &vt.required_udp_ports {
check_port(Protocol::UDP, k)?
}

Ok(())
}

// TODO: probably better to enhance ContextKey::Scan to contain target and scan_id?
fn generate_key(&self, target: &str) -> ContextKey {
ContextKey::Scan(format!("{}-{}", self.scan.scan_id, target))
}
Expand Down Expand Up @@ -533,13 +568,19 @@ where
mod tests {
use storage::{Dispatcher, Retriever};





#[derive(Debug, Default)]
struct GenerateScript {
id: String,
rc: usize,
dependencies: Vec<String>,
required_keys: Vec<String>,
mandatory_keys: Vec<String>,
required_tcp_ports: Vec<String>,
required_udp_ports: Vec<String>,
exclude: Vec<String>,
}

Expand Down Expand Up @@ -581,6 +622,27 @@ mod tests {
}
}

fn with_required_ports(id: &str, ports: &[(models::Protocol, &str)]) -> GenerateScript {
let required_tcp_ports = ports
.iter()
.filter(|(p, _)| matches!(p, models::Protocol::TCP))
.map(|(_, p)| p.to_string())
.collect();
let required_udp_ports = ports
.iter()
.filter(|(p, _)| matches!(p, models::Protocol::UDP))
.map(|(_, p)| p.to_string())
.collect();

GenerateScript {
id: id.to_string(),
required_tcp_ports,
required_udp_ports,

..Default::default()
}
}

fn generate(&self) -> (String, storage::item::Nvt) {
let keys = |x: &[String]| -> String {
x.iter().fold(String::default(), |acc, e| {
Expand All @@ -604,6 +666,8 @@ mod tests {
let required = printable("script_require_keys", &self.required_keys);
let dependencies = printable("script_dependencies", &self.dependencies);
let exclude = printable("script_exclude_keys", &self.exclude);
let require_ports = printable("script_require_ports", &self.required_tcp_ports);
let require_udp_ports = printable("script_require_udp_ports", &self.required_udp_ports);

let rc = self.rc;
let id = &self.id;
Expand All @@ -618,6 +682,8 @@ if (description)
{mandatory}
{required}
{exclude}
{require_ports}
{require_udp_ports}
exit(0);
}}
exit({rc});
Expand Down Expand Up @@ -714,6 +780,84 @@ exit({rc});
Ok(results)
}

#[test]
#[tracing_test::traced_test]
fn required_ports() {
let vts = [
GenerateScript::with_required_ports(
"0",
&[
(models::Protocol::UDP, "2000"),
(models::Protocol::TCP, "20"),
],
)
.generate(),
GenerateScript::with_required_ports(
"1",
&[
(models::Protocol::UDP, "2000"),
(models::Protocol::TCP, "2"),
],
)
.generate(),
GenerateScript::with_required_ports(
"2",
&[
(models::Protocol::UDP, "200"),
(models::Protocol::TCP, "20"),
],
)
.generate(),
GenerateScript::with_required_ports(
"3",
&[
(models::Protocol::UDP, "2000"),
(models::Protocol::TCP, "22"),
],
)
.generate(),
GenerateScript::with_required_ports(
"4",
&[
(models::Protocol::UDP, "2002"),
(models::Protocol::TCP, "20"),
],
)
.generate(),
];
let dispatcher = prepare_vt_storage(&vts);
[
(models::Protocol::TCP, "20", 1), // TCP 20 is considered enabled
(models::Protocol::TCP, "22", 0), // TCP 22 is considered disabled
(models::Protocol::UDP, "2000", 1), // UDP 2000 is considered enabled
(models::Protocol::UDP, "2002", 0), // UDP 2002 is considered disabled
]
.into_iter()
.for_each(|(p, port, enabled)| {
dispatcher
.dispatch(
&storage::ContextKey::Scan("sid".into()),
storage::Field::KB((&super::generate_port_kb_key(p, port), enabled).into()),
)
.expect("store kb");
});
let result = run(&vts, dispatcher).expect("success run");
let success = result
.clone()
.into_iter()
.filter_map(|x| x.ok())
.filter(|x| x.has_succeeded())
.collect::<Vec<_>>();
let failure = result
.into_iter()
.filter_map(|x| x.ok())
.filter(|x| x.has_failed())
.filter(|x| x.has_not_run())
.collect::<Vec<_>>();
assert_eq!(success.len(), 1);
assert_eq!(failure.len(), 4);
}

#[test]
#[tracing_test::traced_test]
fn exclude_keys() {
Expand Down
9 changes: 9 additions & 0 deletions rust/storage/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ pub enum ContextKey {
FileName(String),
}

impl Display for ContextKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ContextKey::Scan(id) => write!(f, "scan_id={id}"),
ContextKey::FileName(name) => write!(f, "file={name}"),
}
}
}

impl AsRef<str> for ContextKey {
fn as_ref(&self) -> &str {
match self {
Expand Down

0 comments on commit 661d652

Please sign in to comment.