Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pd: validate PD list #1201

Merged
merged 9 commits into from
Dec 1, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 117 additions & 38 deletions src/pd/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ use std::time::Duration;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
use std::collections::HashSet;
use util::codec::rpc;
use util::make_std_tcp_conn;

use rand::{self, Rng};

use kvproto::pdpb::{Request, Response};
use kvproto::pdpb::{self, Request, Response};
use kvproto::msgpb::{Message, MessageType};

use super::{Result, PdClient};
use super::{Result, protocol};
use super::metrics::*;

const MAX_PD_SEND_RETRY_COUNT: usize = 100;
Expand All @@ -34,16 +35,21 @@ const SOCKET_WRITE_TIMEOUT: u64 = 3;

const PD_RPC_PREFIX: &'static str = "/pd/rpc";

// Only for `validate_endpoints`.
const VALIDATE_MSG_ID: u64 = 0;
const VALIDATE_CLUSTER_ID: u64 = 0;

#[derive(Debug)]
struct RpcClientCore {
endpoints: String,
endpoints: Vec<String>,
stream: Option<TcpStream>,
}

fn send_msg(stream: &mut TcpStream, msg_id: u64, message: &Request) -> Result<(u64, Response)> {
let timer = PD_SEND_MSG_HISTOGRAM.start_timer();

let mut req = Message::new();

req.set_msg_type(MessageType::PdReq);
// TODO: optimize clone later in HTTP refactor.
req.set_pd_req(message.clone());
Expand All @@ -62,42 +68,50 @@ fn send_msg(stream: &mut TcpStream, msg_id: u64, message: &Request) -> Result<(u
Ok((id, resp.take_pd_resp()))
}

fn rpc_connect(endpoints: &str) -> Result<TcpStream> {
// Randomize hosts.
let mut hosts: Vec<String> = endpoints.split(',').map(|s| s.into()).collect();
rand::thread_rng().shuffle(&mut hosts);

for host in &hosts {
let mut stream = match make_std_tcp_conn(host.as_str()) {
Ok(stream) => stream,
Err(_) => continue,
};
try!(stream.set_write_timeout(Some(Duration::from_secs(SOCKET_WRITE_TIMEOUT))));

// Send a HTTP header to tell PD to hijack this connection for RPC.
let header_str = format!("GET {} HTTP/1.0\r\n\r\n", PD_RPC_PREFIX);
let header = header_str.as_bytes();
match stream.write_all(header) {
Ok(_) => return Ok(stream),
Err(_) => continue,
}
}
fn rpc_connect(endpoint: &str) -> Result<TcpStream> {
let mut stream = try!(make_std_tcp_conn(endpoint));
try!(stream.set_write_timeout(Some(Duration::from_secs(SOCKET_WRITE_TIMEOUT))));

Err(box_err!("failed to connect to {:?}", hosts))
// Send a HTTP header to tell PD to hijack this connection for RPC.
let header_str = format!("GET {} HTTP/1.0\r\n\r\n", PD_RPC_PREFIX);
let header = header_str.as_bytes();
match stream.write_all(header) {
Ok(_) => Ok(stream),
Err(err) => Err(box_err!("failed to connect to {} error: {:?}", endpoint, err)),
}
}

impl RpcClientCore {
fn new(endpoints: &str) -> RpcClientCore {
fn new(endpoints: Vec<String>) -> RpcClientCore {
RpcClientCore {
endpoints: endpoints.into(),
endpoints: endpoints,
stream: None,
}
}

fn try_connect(&mut self) -> Result<()> {
let stream = try!(rpc_connect(&self.endpoints));
self.stream = Some(stream);
Ok(())
// Randomize endpoints.
let len = self.endpoints.len();
let mut indexes: Vec<usize> = (0..len).collect();
rand::thread_rng().shuffle(&mut indexes);

for i in indexes {
let ep = &self.endpoints[i];
match rpc_connect(ep.as_str()) {
Ok(stream) => {
info!("PD client connects to {}", ep);
self.stream = Some(stream);
return Ok(());
}

Err(_) => {
error!("failed to connect to {}, try next", ep);
continue;
}
}
}

Err(box_err!("failed to connect to {:?}", self.endpoints))
}

fn send(&mut self, msg_id: u64, req: &Request) -> Result<Response> {
Expand Down Expand Up @@ -147,25 +161,34 @@ pub struct RpcClient {

impl RpcClient {
pub fn new(endpoints: &str) -> Result<RpcClient> {
let mut client = RpcClient {
msg_id: AtomicUsize::new(0),
core: Mutex::new(RpcClientCore::new(endpoints)),
cluster_id: 0,
};
let endpoints: Vec<String> = endpoints.split(',')
.map(|s| s.trim().to_owned())
.filter(|s| !s.is_empty())
.collect();

let mut cluster_id = VALIDATE_CLUSTER_ID;
for _ in 0..MAX_PD_SEND_RETRY_COUNT {
match client.get_cluster_id() {
match Self::validate_endpoints(&endpoints) {
Ok(id) => {
client.cluster_id = id;
return Ok(client);
cluster_id = id;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

break the loop

break;
}
Err(e) => {
warn!("failed to get cluster id from pd: {:?}", e);
thread::sleep(Duration::from_secs(1));
}
}
}
Err(box_err!("failed to get cluster id from pd"))

if cluster_id == VALIDATE_CLUSTER_ID {
return Err(box_err!("failed to get cluster id from pd"));
}

Ok(RpcClient {
msg_id: AtomicUsize::new(0),
core: Mutex::new(RpcClientCore::new(endpoints)),
cluster_id: cluster_id,
})
}

pub fn send(&self, req: &Request) -> Result<Response> {
Expand All @@ -177,4 +200,60 @@ impl RpcClient {
/// Allocates the next RPC message id. `fetch_add` returns the previous
/// value, so ids start at 0 and increase by one per call; `Relaxed`
/// ordering is enough because the atomic increment alone guarantees each
/// caller gets a distinct id — no ordering with other memory is needed.
fn alloc_msg_id(&self) -> u64 {
self.msg_id.fetch_add(1, Ordering::Relaxed) as u64
}

/// `validate_endpoints` validates pd members, make sure they are in the same cluster.
/// It returns a cluster ID.
/// Notice that it ignores failed pd nodes.
/// Export for tests.
pub fn validate_endpoints(endpoints: &[String]) -> Result<u64> {
if endpoints.is_empty() {
return Err(box_err!("empty PD endpoints"));
}

let len = endpoints.len();
let mut endpoints_set = HashSet::with_capacity(len);

let mut cluster_id = None;
for ep in endpoints {
if !endpoints_set.insert(ep) {
return Err(box_err!("a duplicate PD url {}", ep));
}

let mut stream = match rpc_connect(ep.as_str()) {
Ok(stream) => stream,
// Ignore failed pd node.
Err(_) => continue,
};

let mut req = protocol::new_request(VALIDATE_CLUSTER_ID,
pdpb::CommandType::GetPDMembers);
req.set_get_pd_members(pdpb::GetPDMembersRequest::new());
let (mid, resp) = match send_msg(&mut stream, VALIDATE_MSG_ID, &req) {
Ok((mid, resp)) => (mid, resp),
// Ignore failed pd node.
Err(_) => continue,
};

if mid != VALIDATE_MSG_ID {
return Err(box_err!("PD response msg_id mismatch, want {}, got {}",
VALIDATE_MSG_ID,
mid));
}

// Check cluster ID.
let cid = resp.get_header().get_cluster_id();
if let Some(sample) = cluster_id {
if sample != cid {
return Err(box_err!("PD response cluster_id mismatch, want {}, got {}",
sample,
cid));
}
} else {
cluster_id = Some(cid);
}
// TODO: check all fields later?
}

cluster_id.ok_or(box_err!("PD cluster stop responding"))
}
}
45 changes: 21 additions & 24 deletions src/pd/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl super::PdClient for RpcClient {
// can send this request with any cluster ID, then PD will return its
// cluster ID in the response header.
let get_pd_members = pdpb::GetPDMembersRequest::new();
let mut req = self.new_request(pdpb::CommandType::GetPDMembers);
let mut req = new_request(self.cluster_id, pdpb::CommandType::GetPDMembers);
req.set_get_pd_members(get_pd_members);

let mut resp = try!(self.send(&req));
Expand All @@ -35,15 +35,15 @@ impl super::PdClient for RpcClient {
bootstrap.set_store(store);
bootstrap.set_region(region);

let mut req = self.new_request(pdpb::CommandType::Bootstrap);
let mut req = new_request(self.cluster_id, pdpb::CommandType::Bootstrap);
req.set_bootstrap(bootstrap);

let resp = try!(self.send(&req));
check_resp(&resp)
}

fn is_cluster_bootstrapped(&self) -> Result<bool> {
let mut req = self.new_request(pdpb::CommandType::IsBootstrapped);
let mut req = new_request(self.cluster_id, pdpb::CommandType::IsBootstrapped);
req.set_is_bootstrapped(pdpb::IsBootstrappedRequest::new());

let resp = try!(self.send(&req));
Expand All @@ -52,7 +52,7 @@ impl super::PdClient for RpcClient {
}

fn alloc_id(&self) -> Result<u64> {
let mut req = self.new_request(pdpb::CommandType::AllocId);
let mut req = new_request(self.cluster_id, pdpb::CommandType::AllocId);
req.set_alloc_id(pdpb::AllocIdRequest::new());

let resp = try!(self.send(&req));
Expand All @@ -64,7 +64,7 @@ impl super::PdClient for RpcClient {
let mut put_store = pdpb::PutStoreRequest::new();
put_store.set_store(store);

let mut req = self.new_request(pdpb::CommandType::PutStore);
let mut req = new_request(self.cluster_id, pdpb::CommandType::PutStore);
req.set_put_store(put_store);

let resp = try!(self.send(&req));
Expand All @@ -75,7 +75,7 @@ impl super::PdClient for RpcClient {
let mut get_store = pdpb::GetStoreRequest::new();
get_store.set_store_id(store_id);

let mut req = self.new_request(pdpb::CommandType::GetStore);
let mut req = new_request(self.cluster_id, pdpb::CommandType::GetStore);
req.set_get_store(get_store);

let mut resp = try!(self.send(&req));
Expand All @@ -84,7 +84,7 @@ impl super::PdClient for RpcClient {
}

fn get_cluster_config(&self) -> Result<metapb::Cluster> {
let mut req = self.new_request(pdpb::CommandType::GetClusterConfig);
let mut req = new_request(self.cluster_id, pdpb::CommandType::GetClusterConfig);
req.set_get_cluster_config(pdpb::GetClusterConfigRequest::new());

let mut resp = try!(self.send(&req));
Expand All @@ -96,7 +96,7 @@ impl super::PdClient for RpcClient {
let mut get_region = pdpb::GetRegionRequest::new();
get_region.set_region_key(key.to_vec());

let mut req = self.new_request(pdpb::CommandType::GetRegion);
let mut req = new_request(self.cluster_id, pdpb::CommandType::GetRegion);
req.set_get_region(get_region);

let mut resp = try!(self.send(&req));
Expand All @@ -108,7 +108,7 @@ impl super::PdClient for RpcClient {
let mut get_region_by_id = pdpb::GetRegionByIDRequest::new();
get_region_by_id.set_region_id(region_id);

let mut req = self.new_request(pdpb::CommandType::GetRegionByID);
let mut req = new_request(self.cluster_id, pdpb::CommandType::GetRegionByID);
req.set_get_region_by_id(get_region_by_id);

let mut resp = try!(self.send(&req));
Expand All @@ -130,7 +130,7 @@ impl super::PdClient for RpcClient {
heartbeat.set_leader(leader);
heartbeat.set_down_peers(RepeatedField::from_vec(down_peers));

let mut req = self.new_request(pdpb::CommandType::RegionHeartbeat);
let mut req = new_request(self.cluster_id, pdpb::CommandType::RegionHeartbeat);
req.set_region_heartbeat(heartbeat);

let mut resp = try!(self.send(&req));
Expand All @@ -142,7 +142,7 @@ impl super::PdClient for RpcClient {
let mut ask_split = pdpb::AskSplitRequest::new();
ask_split.set_region(region);

let mut req = self.new_request(pdpb::CommandType::AskSplit);
let mut req = new_request(self.cluster_id, pdpb::CommandType::AskSplit);
req.set_ask_split(ask_split);

let mut resp = try!(self.send(&req));
Expand All @@ -154,7 +154,7 @@ impl super::PdClient for RpcClient {
let mut heartbeat = pdpb::StoreHeartbeatRequest::new();
heartbeat.set_stats(stats);

let mut req = self.new_request(pdpb::CommandType::StoreHeartbeat);
let mut req = new_request(self.cluster_id, pdpb::CommandType::StoreHeartbeat);
req.set_store_heartbeat(heartbeat);

let resp = try!(self.send(&req));
Expand All @@ -166,27 +166,24 @@ impl super::PdClient for RpcClient {
report_split.set_left(left);
report_split.set_right(right);

let mut req = self.new_request(pdpb::CommandType::ReportSplit);
let mut req = new_request(self.cluster_id, pdpb::CommandType::ReportSplit);
req.set_report_split(report_split);

let resp = try!(self.send(&req));
check_resp(&resp)
}
}

impl RpcClient {
fn new_request(&self, cmd_type: pdpb::CommandType) -> pdpb::Request {
let mut header = pdpb::RequestHeader::new();
header.set_cluster_id(self.cluster_id);
header.set_uuid(Uuid::new_v4().as_bytes().to_vec());
let mut req = pdpb::Request::new();
req.set_header(header);
req.set_cmd_type(cmd_type);
req
}
pub fn new_request(cluster_id: u64, cmd_type: pdpb::CommandType) -> pdpb::Request {
let mut header = pdpb::RequestHeader::new();
header.set_cluster_id(cluster_id);
header.set_uuid(Uuid::new_v4().as_bytes().to_vec());
let mut req = pdpb::Request::new();
req.set_header(header);
req.set_cmd_type(cmd_type);
req
}


fn check_resp(resp: &pdpb::Response) -> Result<()> {
if !resp.has_header() {
return Err(box_err!("pd response missing header"));
Expand Down
15 changes: 15 additions & 0 deletions tests/pd/test_rpc_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,18 @@ fn test_rpc_client() {
prev_id = alloc_id;
}
}

#[test]
fn test_rpc_client_safely_new() {
    // Both env vars must be present; otherwise skip the test silently.
    let (ep1, ep2) = match (env::var("PD_ENDPOINTS"), env::var("PD_ENDPOINTS_SEP")) {
        (Ok(a), Ok(b)) => (a, b),
        _ => return,
    };
    let endpoints = [ep1, ep2];

    // Validation must reject this endpoint list rather than panic.
    assert!(RpcClient::validate_endpoints(&endpoints).is_err());
}
Loading