Skip to content

Commit

Permalink
make it work with Java
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartwdouglas committed Nov 25, 2024
1 parent cd4a310 commit 37dec93
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 17 deletions.
34 changes: 20 additions & 14 deletions internal/pgproxy/pgproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstr
defer cancel()

logger := log.FromContext(ctx)
logger.Infof("new connection established: %s", conn.RemoteAddr())
logger.Debugf("new connection established: %s", conn.RemoteAddr())

backend, startup, err := connectBackend(ctx, conn)
if err != nil {
Expand All @@ -92,24 +92,27 @@ func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstr
logger.Infof("client disconnected without startup message: %s", conn.RemoteAddr())
return
}
logger.Debugf("startup message: %+v", startup)
logger.Debugf("backend connected: %s", conn.RemoteAddr())
logger.Tracef("startup message: %+v", startup)
logger.Tracef("backend connected: %s", conn.RemoteAddr())

frontend, err := connectFrontend(ctx, connectionFn, startup)
frontend, hijacked, err := connectFrontend(ctx, connectionFn, startup)
if err != nil {
// try again, in case there was a credential rotation
logger.Warnf("failed to connect frontend: %s, trying again", err)
logger.Debugf("failed to connect frontend: %s, trying again", err)

frontend, err = connectFrontend(ctx, connectionFn, startup)
frontend, hijacked, err = connectFrontend(ctx, connectionFn, startup)
if err != nil {
handleBackendError(ctx, backend, err)
return
}
}
backend.Send(&pgproto3.AuthenticationOk{})
logger.Debugf("frontend connected")
for key, value := range hijacked.ParameterStatuses {
backend.Send(&pgproto3.ParameterStatus{Name: key, Value: value})
}

backend.Send(&pgproto3.AuthenticationOk{})
backend.Send(&pgproto3.ReadyForQuery{})
backend.Send(&pgproto3.ReadyForQuery{TxStatus: 'I'})
if err := backend.Flush(); err != nil {
logger.Errorf(err, "failed to flush backend authentication ok")
return
Expand Down Expand Up @@ -173,19 +176,22 @@ func connectBackend(ctx context.Context, conn net.Conn) (*pgproto3.Backend, *pgp
}
}

func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, error) {
func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, *pgconn.HijackedConn, error) {
dsn, err := connectionFn(ctx, startup.Parameters)
if err != nil {
return nil, fmt.Errorf("failed to construct dsn: %w", err)
return nil, nil, fmt.Errorf("failed to construct dsn: %w", err)
}

conn, err := pgconn.Connect(ctx, dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to backend: %w", err)
return nil, nil, fmt.Errorf("failed to connect to backend: %w", err)
}
frontend := pgproto3.NewFrontend(conn.Conn(), conn.Conn())

return frontend, nil
hijacked, err := conn.Hijack()
if err != nil {
return nil, nil, fmt.Errorf("failed to hijack backend: %w", err)
}
frontend := hijacked.Frontend
return frontend, hijacked, nil
}

func proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error {
Expand Down
8 changes: 6 additions & 2 deletions internal/pgproxy/pgproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import (
"net"
"testing"

"github.com/alecthomas/assert/v2"
"github.com/jackc/pgx/v5/pgproto3"

"github.com/TBD54566975/ftl/internal/dev"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/pgproxy"
"github.com/alecthomas/assert/v2"
"github.com/jackc/pgx/v5/pgproto3"
)

func TestPgProxy(t *testing.T) {
Expand Down Expand Up @@ -48,6 +49,9 @@ func TestPgProxy(t *testing.T) {
assert.NoError(t, frontend.Flush())

assertResponseType[*pgproto3.AuthenticationOk](t, frontend)
for range 13 {
assertResponseType[*pgproto3.ParameterStatus](t, frontend)
}
assertResponseType[*pgproto3.ReadyForQuery](t, frontend)
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
import io.quarkus.agroal.spi.JdbcDataSourceBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.GeneratedResourceBuildItem;
import io.quarkus.deployment.builditem.SystemPropertyBuildItem;
import xyz.block.ftl.runtime.FTLDatasourceCredentials;
import xyz.block.ftl.runtime.FTLRecorder;
import xyz.block.ftl.runtime.config.FTLConfigSource;
import xyz.block.ftl.v1.ModuleContextResponse;
import xyz.block.ftl.v1.schema.Database;
import xyz.block.ftl.v1.schema.Decl;

Expand All @@ -21,10 +25,12 @@ public class DatasourceProcessor {
private static final Logger log = Logger.getLogger(DatasourceProcessor.class);

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public SchemaContributorBuildItem registerDatasources(
List<JdbcDataSourceBuildItem> datasources,
BuildProducer<SystemPropertyBuildItem> systemPropProducer,
BuildProducer<GeneratedResourceBuildItem> generatedResourceBuildItemBuildProducer) {
BuildProducer<GeneratedResourceBuildItem> generatedResourceBuildItemBuildProducer,
FTLRecorder recorder) {
log.infof("Processing %d datasource annotations into decls", datasources.size());
List<Decl> decls = new ArrayList<>();
List<String> namedDatasources = new ArrayList<>();
Expand All @@ -37,6 +43,11 @@ public SchemaContributorBuildItem registerDatasources(
// FTL and quarkus use slightly different names
dbKind = "postgres";
}
if (dbKind.equals("mysql")) {
recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.MYSQL);
} else {
recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.POSTGRES);
}
//default name is <default> which is not a valid name
String sanitisedName = ds.getName().replace("<", "").replace(">", "");
//we use a dynamic credentials provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -36,6 +38,8 @@ public class FTLController implements LeaseClient {

private static volatile FTLController controller;

private final Map<String, ModuleContextResponse.DBType> databases = new ConcurrentHashMap<>();

/**
* TODO: look at how init should work, this is terrible and will break dev mode
*/
Expand Down Expand Up @@ -71,6 +75,10 @@ public static FTLController instance() {
verbService = VerbServiceGrpc.newStub(channel);
}

public void registerDatabase(String name, ModuleContextResponse.DBType type) {
databases.put(name, type);
}

public byte[] getSecret(String secretName) {
var context = getModuleContext();
if (context.containsSecrets(secretName)) {
Expand All @@ -88,6 +96,10 @@ public byte[] getConfig(String secretName) {
}

public Datasource getDatasource(String name) {
if (databases.get(name) == ModuleContextResponse.DBType.POSTGRES) {
var proxyAddress = System.getenv("FTL_PROXY_POSTGRES_ADDRESS");
return new Datasource("jdbc:postgresql://" + proxyAddress + "/" + name, "ftl", "ftl");
}
List<ModuleContextResponse.DSN> databasesList = getModuleContext().getDatabasesList();
for (var i : databasesList) {
if (i.getName().equals(name)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import xyz.block.ftl.runtime.http.FTLHttpHandler;
import xyz.block.ftl.runtime.http.HTTPVerbInvoker;
import xyz.block.ftl.v1.CallRequest;
import xyz.block.ftl.v1.ModuleContextResponse;

@Recorder
public class FTLRecorder {
Expand Down Expand Up @@ -171,4 +172,8 @@ public void run() {
}
});
}

public void registerDatabase(String dbKind, ModuleContextResponse.DBType name) {
FTLController.instance().registerDatabase(dbKind, name);
}
}

0 comments on commit 37dec93

Please sign in to comment.