diff --git a/internal/pgproxy/pgproxy.go b/internal/pgproxy/pgproxy.go index ae9b5498a2..35ed1f314a 100644 --- a/internal/pgproxy/pgproxy.go +++ b/internal/pgproxy/pgproxy.go @@ -81,7 +81,7 @@ func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstr defer cancel() logger := log.FromContext(ctx) - logger.Infof("new connection established: %s", conn.RemoteAddr()) + logger.Debugf("new connection established: %s", conn.RemoteAddr()) backend, startup, err := connectBackend(ctx, conn) if err != nil { @@ -92,24 +92,27 @@ func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstr logger.Infof("client disconnected without startup message: %s", conn.RemoteAddr()) return } - logger.Debugf("startup message: %+v", startup) - logger.Debugf("backend connected: %s", conn.RemoteAddr()) + logger.Tracef("startup message: %+v", startup) + logger.Tracef("backend connected: %s", conn.RemoteAddr()) - frontend, err := connectFrontend(ctx, connectionFn, startup) + frontend, hijacked, err := connectFrontend(ctx, connectionFn, startup) if err != nil { // try again, in case there was a credential rotation - logger.Warnf("failed to connect frontend: %s, trying again", err) + logger.Debugf("failed to connect frontend: %s, trying again", err) - frontend, err = connectFrontend(ctx, connectionFn, startup) + frontend, hijacked, err = connectFrontend(ctx, connectionFn, startup) if err != nil { handleBackendError(ctx, backend, err) return } } + backend.Send(&pgproto3.AuthenticationOk{}) logger.Debugf("frontend connected") + for key, value := range hijacked.ParameterStatuses { + backend.Send(&pgproto3.ParameterStatus{Name: key, Value: value}) + } - backend.Send(&pgproto3.AuthenticationOk{}) - backend.Send(&pgproto3.ReadyForQuery{}) + backend.Send(&pgproto3.ReadyForQuery{TxStatus: 'I'}) if err := backend.Flush(); err != nil { logger.Errorf(err, "failed to flush backend authentication ok") return @@ -173,19 +176,22 @@ func connectBackend(ctx context.Context, conn net.Conn) (*pgproto3.Backend, *pgp } } -func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, error) { +func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, *pgconn.HijackedConn, error) { dsn, err := connectionFn(ctx, startup.Parameters) if err != nil { - return nil, fmt.Errorf("failed to construct dsn: %w", err) + return nil, nil, fmt.Errorf("failed to construct dsn: %w", err) } conn, err := pgconn.Connect(ctx, dsn) if err != nil { - return nil, fmt.Errorf("failed to connect to backend: %w", err) + return nil, nil, fmt.Errorf("failed to connect to backend: %w", err) } - frontend := pgproto3.NewFrontend(conn.Conn(), conn.Conn()) - - return frontend, nil + hijacked, err := conn.Hijack() + if err != nil { + return nil, nil, fmt.Errorf("failed to hijack backend: %w", err) + } + frontend := hijacked.Frontend + return frontend, hijacked, nil } func proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error { diff --git a/internal/pgproxy/pgproxy_test.go b/internal/pgproxy/pgproxy_test.go index 2ef6633e5a..5d73f0d7f3 100644 --- a/internal/pgproxy/pgproxy_test.go +++ b/internal/pgproxy/pgproxy_test.go @@ -5,11 +5,12 @@ import ( "net" "testing" + "github.com/alecthomas/assert/v2" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/TBD54566975/ftl/internal/dev" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/pgproxy" - "github.com/alecthomas/assert/v2" - "github.com/jackc/pgx/v5/pgproto3" ) func TestPgProxy(t *testing.T) { @@ -48,6 +49,9 @@ func TestPgProxy(t *testing.T) { assert.NoError(t, frontend.Flush()) assertResponseType[*pgproto3.AuthenticationOk](t, frontend) + for range 13 { + assertResponseType[*pgproto3.ParameterStatus](t, frontend) + } assertResponseType[*pgproto3.ReadyForQuery](t, frontend) }) diff --git a/jvm-runtime/ftl-runtime/common/deployment/src/main/java/xyz/block/ftl/deployment/DatasourceProcessor.java b/jvm-runtime/ftl-runtime/common/deployment/src/main/java/xyz/block/ftl/deployment/DatasourceProcessor.java index a198a3d697..a629504d5e 100644 --- a/jvm-runtime/ftl-runtime/common/deployment/src/main/java/xyz/block/ftl/deployment/DatasourceProcessor.java +++ b/jvm-runtime/ftl-runtime/common/deployment/src/main/java/xyz/block/ftl/deployment/DatasourceProcessor.java @@ -9,10 +9,14 @@ import io.quarkus.agroal.spi.JdbcDataSourceBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; import io.quarkus.deployment.builditem.GeneratedResourceBuildItem; import io.quarkus.deployment.builditem.SystemPropertyBuildItem; import xyz.block.ftl.runtime.FTLDatasourceCredentials; +import xyz.block.ftl.runtime.FTLRecorder; import xyz.block.ftl.runtime.config.FTLConfigSource; +import xyz.block.ftl.v1.ModuleContextResponse; import xyz.block.ftl.v1.schema.Database; import xyz.block.ftl.v1.schema.Decl; @@ -21,10 +25,12 @@ public class DatasourceProcessor { private static final Logger log = Logger.getLogger(DatasourceProcessor.class); @BuildStep + @Record(ExecutionTime.STATIC_INIT) public SchemaContributorBuildItem registerDatasources( List datasources, BuildProducer systemPropProducer, - BuildProducer generatedResourceBuildItemBuildProducer) { + BuildProducer generatedResourceBuildItemBuildProducer, + FTLRecorder recorder) { log.infof("Processing %d datasource annotations into decls", datasources.size()); List decls = new ArrayList<>(); List namedDatasources = new ArrayList<>(); @@ -37,6 +43,11 @@ public SchemaContributorBuildItem registerDatasources( // FTL and quarkus use slightly different names dbKind = "postgres"; } + if (dbKind.equals("mysql")) { + recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.MYSQL); + } else { + recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.POSTGRES); + } //default name is which is not a valid name String sanitisedName = ds.getName().replace("<", "").replace(">", ""); //we use a dynamic credentials provider diff --git a/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLController.java b/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLController.java index f51ca8d98e..ccc429e35c 100644 --- a/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLController.java +++ b/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLController.java @@ -5,7 +5,9 @@ import java.time.Duration; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Pattern; @@ -36,6 +38,8 @@ public class FTLController implements LeaseClient { private static volatile FTLController controller; + private final Map databases = new ConcurrentHashMap<>(); + /** * TODO: look at how init should work, this is terrible and will break dev mode */ @@ -71,6 +75,10 @@ public static FTLController instance() { verbService = VerbServiceGrpc.newStub(channel); } + public void registerDatabase(String name, ModuleContextResponse.DBType type) { + databases.put(name, type); + } + public byte[] getSecret(String secretName) { var context = getModuleContext(); if (context.containsSecrets(secretName)) { @@ -88,6 +96,10 @@ public byte[] getConfig(String secretName) { } public Datasource getDatasource(String name) { + if (databases.get(name) == ModuleContextResponse.DBType.POSTGRES) { + var proxyAddress = System.getenv("FTL_PROXY_POSTGRES_ADDRESS"); + return new Datasource("jdbc:postgresql://" + proxyAddress + "/" + name, "ftl", "ftl"); + } List databasesList = getModuleContext().getDatabasesList(); for (var i : databasesList) { if (i.getName().equals(name)) { diff --git a/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLRecorder.java b/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLRecorder.java index 1ec9cf0ca1..f3c497777d 100644 --- a/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLRecorder.java +++ b/jvm-runtime/ftl-runtime/common/runtime/src/main/java/xyz/block/ftl/runtime/FTLRecorder.java @@ -18,6 +18,7 @@ import xyz.block.ftl.runtime.http.FTLHttpHandler; import xyz.block.ftl.runtime.http.HTTPVerbInvoker; import xyz.block.ftl.v1.CallRequest; +import xyz.block.ftl.v1.ModuleContextResponse; @Recorder public class FTLRecorder { @@ -171,4 +172,8 @@ public void run() { } }); } + + public void registerDatabase(String dbKind, ModuleContextResponse.DBType name) { + FTLController.instance().registerDatabase(dbKind, name); + } }