Skip to content

Commit

Permalink
Store SIA user agent information in x509 certificate request table (#…
Browse files Browse the repository at this point in the history
…2772)

* extract provider name from SIA user agent and store in x509 certificate request table

---------

Signed-off-by: rajeshal <[email protected]>
Co-authored-by: rajeshal <[email protected]>
  • Loading branch information
rajeshal and rajeshal authored Oct 24, 2024
1 parent e6a62b6 commit 483e0ab
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class DynamoDBCertRecordStoreConnection implements CertRecordStoreConnect
private static final String KEY_TTL = "ttl";
private static final String KEY_REGISTER_TIME = "registerTime";
private static final String KEY_SVC_DATA_UPDATE_TIME = "svcDataUpdateTime";
private static final String KEY_SIA_PROVIDER = "siaProvider";

// the configuration setting is in hours, so we'll automatically
// convert into seconds since that's what dynamoDB needs
Expand Down Expand Up @@ -143,6 +144,7 @@ private X509CertRecord itemToX509CertRecord(Map<String, AttributeValue> item) {
certRecord.setExpiryTime(DynamoDBUtils.getDateFromItem(item, KEY_EXPIRY_TIME));
certRecord.setHostName(DynamoDBUtils.getString(item, KEY_HOSTNAME));
certRecord.setSvcDataUpdateTime(DynamoDBUtils.getDateFromItem(item, KEY_SVC_DATA_UPDATE_TIME));
certRecord.setSiaProvider(DynamoDBUtils.getString(item, KEY_SIA_PROVIDER));
return certRecord;
}

Expand Down Expand Up @@ -186,6 +188,7 @@ public boolean updateX509CertRecord(X509CertRecord certRecord) {
DynamoDBUtils.updateItemLongValue(updatedValues, KEY_SVC_DATA_UPDATE_TIME, certRecord.getSvcDataUpdateTime());
DynamoDBUtils.updateItemLongValue(updatedValues, KEY_EXPIRY_TIME, certRecord.getExpiryTime());
DynamoDBUtils.updateItemStringValue(updatedValues, KEY_HOSTNAME, hostName);
DynamoDBUtils.updateItemStringValue(updatedValues, KEY_SIA_PROVIDER, certRecord.getSiaProvider());

UpdateItemRequest request = UpdateItemRequest.builder()
.tableName(tableName)
Expand Down Expand Up @@ -232,6 +235,7 @@ public boolean insertX509CertRecord(X509CertRecord certRecord) {
itemValues.put(KEY_SVC_DATA_UPDATE_TIME, AttributeValue.fromN(DynamoDBUtils.getNumberFromDate(certRecord.getSvcDataUpdateTime())));
itemValues.put(KEY_REGISTER_TIME, AttributeValue.fromN(String.valueOf(System.currentTimeMillis())));
itemValues.put(KEY_HOSTNAME, AttributeValue.fromS(hostName));
itemValues.put(KEY_SIA_PROVIDER, AttributeValue.fromS(certRecord.getSiaProvider()));

PutItemRequest request = PutItemRequest.builder()
.tableName(tableName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class X509CertRecord {
private Date expiryTime;
private String hostName;
private Date svcDataUpdateTime;
private String siaProvider;

public X509CertRecord() {
}
Expand Down Expand Up @@ -158,6 +159,14 @@ public void setSvcDataUpdateTime(Date svcDataUpdateTime) {
this.svcDataUpdateTime = svcDataUpdateTime;
}

public String getSiaProvider() {
return siaProvider;
}

public void setSiaProvider(String siaProvider) {
this.siaProvider = siaProvider;
}

@Override
public String toString() {
return "X509CertRecord{" +
Expand All @@ -176,6 +185,7 @@ public String toString() {
", expiryTime=" + expiryTime +
", hostName='" + hostName + '\'' +
", svcDataUpdateTime=" + svcDataUpdateTime +
", siaProvider='" + siaProvider + '\'' +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ public class JDBCCertRecordStoreConnection implements CertRecordStoreConnection
private static final String SQL_GET_X509_RECORD = "SELECT * FROM certificates WHERE provider=? AND instanceId=? AND service=?;";
private static final String SQL_INSERT_X509_RECORD = "INSERT INTO certificates " +
"(provider, instanceId, service, currentSerial, currentTime, currentIP, prevSerial, prevTime, prevIP, clientCert, " +
"expiryTime, hostName) " +
"VALUES (?,?,?,?,?,?,?,?,?,?,?,?);";
"expiryTime, hostName, siaProvider) " +
"VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?);";
private static final String SQL_UPDATE_X509_RECORD = "UPDATE certificates SET " +
"currentSerial=?, currentTime=?, currentIP=?, prevSerial=?, prevTime=?, prevIP=?, " +
"expiryTime=?, hostName=?, clientCert=? " +
"expiryTime=?, hostName=?, clientCert=?, siaProvider=? " +
"WHERE provider=? AND instanceId=? AND service=?;";
private static final String SQL_DELETE_X509_RECORD = "DELETE from certificates " +
"WHERE provider=? AND instanceId=? AND service=?;";
Expand Down Expand Up @@ -85,7 +85,8 @@ public class JDBCCertRecordStoreConnection implements CertRecordStoreConnection
public static final String DB_COLUMN_LAST_NOTIFIED_SERVER = "lastNotifiedServer";
public static final String DB_COLUMN_EXPIRY_TIME = "expiryTime";
public static final String DB_COLUMN_HOSTNAME = "hostName";

public static final String DB_COLUMN_SIA_PROVIDER = "siaProvider";

Connection con;
int queryTimeout = 10;

Expand Down Expand Up @@ -169,6 +170,7 @@ private X509CertRecord setRecordFromResultSet(ResultSet rs) throws SQLException
certRecord.setLastNotifiedServer(rs.getString(DB_COLUMN_LAST_NOTIFIED_SERVER));
certRecord.setExpiryTime(getDateFromResultSet(rs, DB_COLUMN_EXPIRY_TIME));
certRecord.setHostName(rs.getString(DB_COLUMN_HOSTNAME));
certRecord.setSiaProvider(rs.getString(DB_COLUMN_SIA_PROVIDER));
return certRecord;
}

Expand Down Expand Up @@ -205,9 +207,10 @@ public boolean updateX509CertRecord(X509CertRecord certRecord) throws ServerReso
ps.setTimestamp(7, getTimestampFromDate(certRecord.getExpiryTime()));
ps.setString(8, certRecord.getHostName());
ps.setBoolean(9, certRecord.getClientCert());
ps.setString(10, certRecord.getProvider());
ps.setString(11, certRecord.getInstanceId());
ps.setString(12, certRecord.getService());
ps.setString(10, certRecord.getSiaProvider());
ps.setString(11, certRecord.getProvider());
ps.setString(12, certRecord.getInstanceId());
ps.setString(13, certRecord.getService());
affectedRows = executeUpdate(ps, caller);
} catch (SQLException ex) {
throw sqlError(ex, caller);
Expand All @@ -234,6 +237,7 @@ public boolean insertX509CertRecord(X509CertRecord certRecord) throws ServerReso
ps.setBoolean(10, certRecord.getClientCert());
ps.setTimestamp(11, getTimestampFromDate(certRecord.getExpiryTime()));
ps.setString(12, certRecord.getHostName());
ps.setString(13, certRecord.getSiaProvider());

affectedRows = executeUpdate(ps, caller);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@
*/
package com.yahoo.athenz.common.server.util;

import com.yahoo.athenz.auth.util.StringUtils;
import jakarta.servlet.http.HttpServletRequest;
import com.google.common.net.InetAddresses;
import org.eclipse.jetty.http.HttpHeader;

import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class ServletRequestUtil {

public static final String LOOPBACK_ADDRESS = "127.0.0.1";
public static final String XFF_HEADER = "X-Forwarded-For";
public static final Pattern USER_AGENT_PATTERN = Pattern.compile("SIA-([^ ]+)([ ]+[^ ]*|$)");

/**
* Return the remote client IP address.
Expand All @@ -45,5 +51,23 @@ public static String getRemoteAddress(final HttpServletRequest request) {
}
return addr;
}

/**
* Return the SIA provider from user agent header, which is set by sia agent as request header.
* SIA agent header value is in the format 'SIA-<provider> <version> like 'SIA-FARGATE 1.32.0'.
* It extract just the provider name from the agent header value and return that.
* @param request http servlet request
* @return SIA provider
**/
public static String getSiaProvider(final HttpServletRequest request) {
final String userAgent = request.getHeader(HttpHeader.USER_AGENT.asString());
if (!StringUtils.isEmpty(userAgent)) {
Matcher matcher = USER_AGENT_PATTERN.matcher(userAgent.trim());
if (matcher.matches()) {
return matcher.group(1);
}
}
return null;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public void testX509CertRecord() {
certRecord.setLastNotifiedTime(now);
certRecord.setExpiryTime(now);
certRecord.setSvcDataUpdateTime(now);
certRecord.setSiaProvider("EKS");

assertEquals(certRecord.getService(), "cn");
assertEquals(certRecord.getProvider(), "ostk");
Expand All @@ -66,7 +67,7 @@ public void testX509CertRecord() {
"currentSerial='current-serial', currentTime=" + now + ", currentIP='current-ip', " +
"prevSerial='prev-serial', prevTime=" + now + ", prevIP='prev-ip', clientCert=true, " +
"lastNotifiedTime=" + now + ", lastNotifiedServer='server', expiryTime=" + now + ", " +
"hostName='host', svcDataUpdateTime=" + now + "}";
"hostName='host', svcDataUpdateTime=" + now + ", siaProvider='EKS'}";
assertEquals(certRecord.toString(), certStr);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ public void testGetX509CertRecordNullableColumns() throws Exception {
assertNull(certRecord.getLastNotifiedServer());
assertNull(certRecord.getExpiryTime());
assertNull(certRecord.getHostName());
assertNull(certRecord.getSiaProvider());

jdbcConn.close();
}
Expand Down Expand Up @@ -148,7 +149,6 @@ public void testInsertX509Record() throws Exception {
certRecord.setLastNotifiedServer("last-notified-server");
certRecord.setExpiryTime(now);
certRecord.setHostName("hostname");

Mockito.doReturn(1).when(mockPrepStmt).executeUpdate();
boolean requestSuccess = jdbcConn.insertX509CertRecord(certRecord);
assertTrue(requestSuccess);
Expand Down Expand Up @@ -223,9 +223,10 @@ public void testInsertX509RecordAlreadyExists() throws Exception {
Mockito.verify(mockPrepStmt, times(1)).setTimestamp(7, new java.sql.Timestamp(now.getTime()));
Mockito.verify(mockPrepStmt, times(1)).setString(8, "hostname");
Mockito.verify(mockPrepStmt, times(1)).setBoolean(9, false);
Mockito.verify(mockPrepStmt, times(1)).setString(10, "ostk");
Mockito.verify(mockPrepStmt, times(1)).setString(11, "instance-id");
Mockito.verify(mockPrepStmt, times(1)).setString(12, "cn");
Mockito.verify(mockPrepStmt, times(1)).setString(10, null);
Mockito.verify(mockPrepStmt, times(1)).setString(11, "ostk");
Mockito.verify(mockPrepStmt, times(1)).setString(12, "instance-id");
Mockito.verify(mockPrepStmt, times(1)).setString(13, "cn");

// common between insert/update so count is 2 times
Mockito.verify(mockPrepStmt, times(2)).setTimestamp(5, new java.sql.Timestamp(now.getTime()));
Expand Down Expand Up @@ -273,8 +274,20 @@ public void testUpdateX509Record() throws Exception {
assertTrue(requestSuccess);

verifyUpdateNonNullableColumns(now);

Mockito.verify(mockPrepStmt, times(1)).setString(1, "current-serial");
Mockito.verify(mockPrepStmt, times(1)).setTimestamp(2, new java.sql.Timestamp(now.getTime()));
Mockito.verify(mockPrepStmt, times(1)).setString(3, "current-ip");
Mockito.verify(mockPrepStmt, times(1)).setString(4, "prev-serial");
Mockito.verify(mockPrepStmt, times(1)).setTimestamp(5, new java.sql.Timestamp(now.getTime()));
Mockito.verify(mockPrepStmt, times(1)).setString(6, "prev-ip");
Mockito.verify(mockPrepStmt, times(1)).setTimestamp(7, new java.sql.Timestamp(now.getTime()));
Mockito.verify(mockPrepStmt, times(1)).setString(8, "hostname");
Mockito.verify(mockPrepStmt, times(1)).setBoolean(9, false);
Mockito.verify(mockPrepStmt, times(1)).setString(10, null);
Mockito.verify(mockPrepStmt, times(1)).setString(11, "ostk");
Mockito.verify(mockPrepStmt, times(1)).setString(12, "instance-id");
Mockito.verify(mockPrepStmt, times(1)).setString(13, "cn");

jdbcConn.close();
}
Expand All @@ -298,6 +311,7 @@ public void testUpdateX509RecordNullableColumns() throws Exception {
verifyUpdateNonNullableColumns(now);
Mockito.verify(mockPrepStmt, times(1)).setTimestamp(7, null);
Mockito.verify(mockPrepStmt, times(1)).setString(8, null);
Mockito.verify(mockPrepStmt, times(1)).setString(10, null);

jdbcConn.close();
}
Expand All @@ -311,9 +325,9 @@ private void verifyUpdateNonNullableColumns(Date now) throws SQLException {
Mockito.verify(mockPrepStmt, times(1)).setString(6, "prev-ip");
Mockito.verify(mockPrepStmt, times(1)).setBoolean(9, false);

Mockito.verify(mockPrepStmt, times(1)).setString(10, "ostk");
Mockito.verify(mockPrepStmt, times(1)).setString(11, "instance-id");
Mockito.verify(mockPrepStmt, times(1)).setString(12, "cn");
Mockito.verify(mockPrepStmt, times(1)).setString(11, "ostk");
Mockito.verify(mockPrepStmt, times(1)).setString(12, "instance-id");
Mockito.verify(mockPrepStmt, times(1)).setString(13, "cn");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@

import jakarta.servlet.http.HttpServletRequest;

import org.eclipse.jetty.http.HttpHeader;
import org.mockito.Mockito;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNull;

import java.util.ArrayList;
import java.util.List;

public class ServletRequestUtilTest {

@Test
Expand Down Expand Up @@ -85,4 +90,33 @@ public void testGetRemoteAddressLoopBackMultipleXFFInvalidIP() {
assertEquals(ServletRequestUtil.getRemoteAddress(httpServletRequest), "127.0.0.1");
assertEquals(ServletRequestUtil.getRemoteAddress(httpServletRequest), "127.0.0.1");
}

@Test
public void testGetSiaAgentWithOutHeader() {
HttpServletRequest httpServletRequest = Mockito.mock(HttpServletRequest.class);
assertNull(ServletRequestUtil.getSiaProvider(httpServletRequest));
}

@Test(dataProvider = "dataGetSiaAgentWithHeader")
public void testGetSiaAgentWithHeader(String headerValue, String expected) {
HttpServletRequest httpServletRequest = Mockito.mock(HttpServletRequest.class);
Mockito.when(httpServletRequest.getHeader(HttpHeader.USER_AGENT.asString())).thenReturn(headerValue);
assertEquals(ServletRequestUtil.getSiaProvider(httpServletRequest), expected);
}

@DataProvider
private Object[][] dataGetSiaAgentWithHeader() {
List<Object[]> data = new ArrayList<>();
data.add(new Object[] {null, null});
data.add(new Object[] {"", null});
data.add(new Object[] {" ", null});
data.add(new Object[] {"SIA-FARGATE 1.32.0", "FARGATE"});
data.add(new Object[] {"SIA-FARGATE 1.32.0", "FARGATE"});
data.add(new Object[] {"SIA-FARGATE ", "FARGATE"});
data.add(new Object[] {"SIA-FARGATE ", "FARGATE"});
data.add(new Object[] {"SIA-FARGATE", "FARGATE"});
// don;t expect tab
data.add(new Object[] {"SIA-FARGATE\t1.32.0", "FARGATE\t1.32.0"});
return data.toArray(new Object[0][]);
}
}
2 changes: 2 additions & 0 deletions servers/zts/schema/updates/update-20241022.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE `zts_store`.`certificates`
ADD COLUMN `siaProvider` VARCHAR(256) NULL;
Binary file modified servers/zts/schema/zts_server.mwb
Binary file not shown.
9 changes: 5 additions & 4 deletions servers/zts/schema/zts_server.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- MySQL Script generated by MySQL Workbench
-- Thu Mar 25 20:56:17 2021
-- Tue Oct 22 22:21:58 2024
-- Model: New Model Version: 1.0
-- MySQL Workbench Forward Engineering

Expand Down Expand Up @@ -35,8 +35,9 @@ CREATE TABLE IF NOT EXISTS `zts_store`.`certificates` (
`lastNotifiedServer` VARCHAR(512) NULL,
`expiryTime` DATETIME(3) NULL,
`hostName` VARCHAR(512) NULL,
`siaProvider` VARCHAR(256) NULL,
PRIMARY KEY (`provider`, `instanceId`, `service`),
INDEX `idx_hostName` (`hostName` ASC))
INDEX `idx_hostName` (`hostName` ASC) VISIBLE)
ENGINE = InnoDB;


Expand Down Expand Up @@ -67,8 +68,8 @@ CREATE TABLE IF NOT EXISTS `zts_store`.`workloads` (
`updateTime` DATETIME(3) NULL DEFAULT CURRENT_TIMESTAMP(3),
`certExpiryTime` DATETIME(3) NOT NULL DEFAULT '1970-01-01 00:00:00.000',
PRIMARY KEY (`instanceId`, `ip`, `service`),
INDEX `idx_service` (`service` ASC),
INDEX `idx_ip` (`ip` ASC))
INDEX `idx_service` (`service` ASC) VISIBLE,
INDEX `idx_ip` (`ip` ASC) VISIBLE)
ENGINE = InnoDB;


Expand Down
2 changes: 2 additions & 0 deletions servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -3563,6 +3563,7 @@ X509CertRecord insertX509CertRecord(ResourceContext ctx, final String cn,
x509CertRecord.setExpiryTime(expirationDate);
x509CertRecord.setHostName(hostName);
x509CertRecord.setSvcDataUpdateTime(new Date());
x509CertRecord.setSiaProvider(ServletRequestUtil.getSiaProvider(ctx.request()));

// we must be able to update our database otherwise we will not be
// able to validate the certificate during refresh operations
Expand Down Expand Up @@ -4649,6 +4650,7 @@ X509CertRecord getValidatedX509CertRecord(ResourceContext ctx, final String prov
}
}

x509CertRecord.setSiaProvider(ServletRequestUtil.getSiaProvider(ctx.request()));
return x509CertRecord;
}

Expand Down
10 changes: 9 additions & 1 deletion servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import com.yahoo.rdl.Struct;
import com.yahoo.rdl.Timestamp;
import jakarta.servlet.ServletContext;
import org.eclipse.jetty.http.HttpHeader;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
Expand Down Expand Up @@ -271,11 +272,18 @@ public void shutdown() {
}

private ResourceContext createResourceContext(Principal principal) {
return createResourceContextWithUserAgent(principal, null);
}

private ResourceContext createResourceContextWithUserAgent(Principal principal, String userAgent) {
ServerResourceContext rsrcCtx = Mockito.mock(ServerResourceContext.class);
Mockito.when(rsrcCtx.principal()).thenReturn(principal);
Mockito.when(rsrcCtx.request()).thenReturn(mockServletRequest);
Mockito.when(mockServletRequest.getRemoteAddr()).thenReturn(MOCKCLIENTADDR);
Mockito.when(mockServletRequest.isSecure()).thenReturn(true);
if (null != userAgent) {
Mockito.when(mockServletRequest.getHeader(HttpHeader.USER_AGENT.asString())).thenReturn(userAgent);
}

RsrcCtxWrapper rsrcCtxWrapper = Mockito.mock(RsrcCtxWrapper.class);
Mockito.when(rsrcCtxWrapper.context()).thenReturn(rsrcCtx);
Expand Down Expand Up @@ -5768,7 +5776,7 @@ public void testPostInstanceRegisterInformationWithIPAndAccount() throws IOExcep
.setDomain("athenz").setService("production")
.setProvider("athenz.provider").setToken(false);

ResourceContext context = createResourceContext(null);
ResourceContext context = createResourceContextWithUserAgent(null, "SIA-FARGATE 1.32.0");

Response response = ztsImpl.postInstanceRegisterInformation(context, info);
assertEquals(response.getStatus(), 201);
Expand Down

0 comments on commit 483e0ab

Please sign in to comment.