Skip to content

Commit

Permalink
fix: fix pg proxy issues, and remove hard coded DSNs (#3501)
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartwdouglas authored Nov 25, 2024
1 parent 6b9628a commit d4720e4
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 65 deletions.
12 changes: 1 addition & 11 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -1063,18 +1063,8 @@ func (s *Service) CreateDeployment(ctx context.Context, req *connect.Request[ftl
return nil, fmt.Errorf("invalid module schema: %w", err)
}

for _, d := range module.Decls {
if db, ok := d.(*schema.Database); ok && db.Runtime != nil {
key := dsnSecretKey(module.Name, db.Name)

if err := s.sm.Set(ctx, configuration.NewRef(module.Name, key), db.Runtime.DSN); err != nil {
return nil, fmt.Errorf("could not set database secret %s: %w", key, err)
}
logger.Infof("Database declaration: %s -> %s type %s", db.Name, db.Runtime.DSN, db.Type)
}
}

dkey, err := s.dal.CreateDeployment(ctx, ms.Runtime.Language, module, artefacts)

if err != nil {
logger.Errorf(err, "Could not create deployment")
return nil, fmt.Errorf("could not create deployment: %w", err)
Expand Down
5 changes: 3 additions & 2 deletions backend/provisioner/provisioner_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import (
"fmt"
"testing"

in "github.com/TBD54566975/ftl/internal/integration"
"github.com/alecthomas/assert/v2"

in "github.com/TBD54566975/ftl/internal/integration"
)

func TestDeploymentThrougDevProvisionerCreatePostgresDB(t *testing.T) {
func TestDeploymentThroughDevProvisionerCreatePostgresDB(t *testing.T) {
in.Run(t,
in.WithFTLConfig("./ftl-project.toml"),
in.CopyModule("echo"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- migrate:up
CREATE TABLE messages( message TEXT );
-- migrate:down
DROP TABLE messages;
11 changes: 2 additions & 9 deletions backend/provisioner/testdata/go/echo/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,12 @@ func (EchoDBConfig) Name() string { return "echodb" }
//
//ftl:verb export
func Echo(ctx context.Context, req string, db ftl.DatabaseHandle[EchoDBConfig]) (string, error) {
_, err := db.Get(ctx).Exec(`CREATE TABLE IF NOT EXISTS messages(
message TEXT
);`)
_, err := db.Get(ctx).Exec(`INSERT INTO messages (message) VALUES ($1);`, req)
if err != nil {
return "", err
}

_, err = db.Get(ctx).Exec(`INSERT INTO messages (message) VALUES ($1);`, req)
if err != nil {
return "", err
}

rows, err := db.Get(ctx).Query(`SELECT message FROM messages;`)
rows, err := db.Get(ctx).Query(`SELECT DISTINCT message FROM messages;`)
if err != nil {
return "", err
}
Expand Down
9 changes: 0 additions & 9 deletions ftl-project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,9 @@ disable-ide-integration = true
key = "inline://InZhbHVlIg"

[modules]
[modules.database]
[modules.database.secrets]
FTL_DSN_DATABASE_TESTDB = "inline://InBvc3RncmVzOi8vMTI3LjAuMC4xOjE1NDMyL2RhdGFiYXNlX3Rlc3RkYj9zc2xtb2RlPWRpc2FibGVcdTAwMjZ1c2VyPXBvc3RncmVzXHUwMDI2cGFzc3dvcmQ9c2VjcmV0Ig"
[modules.echo]
[modules.echo.configuration]
default = "inline://ImFub255bW91cyI"
[modules.mysql]
[modules.mysql.secrets]
FTL_DSN_MYSQL_TESTDB = "inline://InJvb3Q6c2VjcmV0QHRjcCgxMjcuMC4wLjE6MTMzMDYpL215c3FsX3Rlc3RkYj9hbGxvd05hdGl2ZVBhc3N3b3Jkcz1UcnVlIg"
[modules.test]
[modules.test.configuration]
[modules.test.secrets]

[commands]
startup = ["echo 'FTL startup command ⚡️'"]
76 changes: 45 additions & 31 deletions internal/pgproxy/pgproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@ func (p *PgProxy) Start(ctx context.Context, started chan<- Started) error {
// It will block until the connection is closed.
func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstructor) {
defer conn.Close()
ctx, cancel := context.WithCancel(ctx)
defer cancel()

logger := log.FromContext(ctx)
logger.Infof("new connection established: %s", conn.RemoteAddr())
logger.Debugf("new connection established: %s", conn.RemoteAddr())

backend, startup, err := connectBackend(ctx, conn)
if err != nil {
Expand All @@ -90,30 +92,33 @@ func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstr
logger.Infof("client disconnected without startup message: %s", conn.RemoteAddr())
return
}
logger.Debugf("startup message: %+v", startup)
logger.Debugf("backend connected: %s", conn.RemoteAddr())
logger.Tracef("startup message: %+v", startup)
logger.Tracef("backend connected: %s", conn.RemoteAddr())

frontend, err := connectFrontend(ctx, connectionFn, startup)
hijacked, err := connectFrontend(ctx, connectionFn, startup)
if err != nil {
// try again, in case there was a credential rotation
logger.Warnf("failed to connect frontend: %s, trying again", err)
logger.Debugf("failed to connect frontend: %s, trying again", err)

frontend, err = connectFrontend(ctx, connectionFn, startup)
hijacked, err = connectFrontend(ctx, connectionFn, startup)
if err != nil {
handleBackendError(ctx, backend, err)
return
}
}
backend.Send(&pgproto3.AuthenticationOk{})
logger.Debugf("frontend connected")
for key, value := range hijacked.ParameterStatuses {
backend.Send(&pgproto3.ParameterStatus{Name: key, Value: value})
}

backend.Send(&pgproto3.AuthenticationOk{})
backend.Send(&pgproto3.ReadyForQuery{})
backend.Send(&pgproto3.ReadyForQuery{TxStatus: 'I'})
if err := backend.Flush(); err != nil {
logger.Errorf(err, "failed to flush backend authentication ok")
return
}

if err := proxy(ctx, backend, frontend); err != nil {
if err := proxy(ctx, backend, hijacked.Frontend); err != nil {
logger.Warnf("disconnecting %s due to: %s", conn.RemoteAddr(), err)
return
}
Expand Down Expand Up @@ -171,7 +176,7 @@ func connectBackend(ctx context.Context, conn net.Conn) (*pgproto3.Backend, *pgp
}
}

func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, error) {
func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgconn.HijackedConn, error) {
dsn, err := connectionFn(ctx, startup.Parameters)
if err != nil {
return nil, fmt.Errorf("failed to construct dsn: %w", err)
Expand All @@ -181,59 +186,68 @@ func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *
if err != nil {
return nil, fmt.Errorf("failed to connect to backend: %w", err)
}
frontend := pgproto3.NewFrontend(conn.Conn(), conn.Conn())

return frontend, nil
hijacked, err := conn.Hijack()
if err != nil {
return nil, fmt.Errorf("failed to hijack backend: %w", err)
}
return hijacked, nil
}

func proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error {
logger := log.FromContext(ctx)
frontendMessages := make(chan pgproto3.BackendMessage)
backendMessages := make(chan pgproto3.FrontendMessage)
errors := make(chan error, 2)

go func() {
for {
msg, err := backend.Receive()
select {
case <-ctx.Done():
return
default:
}
if err != nil {
errors <- fmt.Errorf("failed to receive backend message: %w", err)
return
}
logger.Tracef("backend message: %T", msg)
backendMessages <- msg
frontend.Send(msg)
err = frontend.Flush()
if err != nil {
errors <- fmt.Errorf("failed to receive backend message: %w", err)
return
}
if _, ok := msg.(*pgproto3.Terminate); ok {
return
}
}
}()

go func() {
for {
msg, err := frontend.Receive()
select {
case <-ctx.Done():
return
default:
}
if err != nil {
errors <- fmt.Errorf("failed to receive frontend message: %w", err)
return
}
logger.Tracef("frontend message: %T", msg)
frontendMessages <- msg
backend.Send(msg)
err = backend.Flush()
if err != nil {
errors <- fmt.Errorf("failed to receive backend message: %w", err)
return
}
}
}()

for {
select {
case <-ctx.Done():
return fmt.Errorf("context done: %w", ctx.Err())
case msg := <-backendMessages:
frontend.Send(msg)
if err := frontend.Flush(); err != nil {
return fmt.Errorf("failed to flush frontend message: %w", err)
}

if _, ok := msg.(*pgproto3.Terminate); ok {
return nil
}
case msg := <-frontendMessages:
backend.Send(msg)
if err := backend.Flush(); err != nil {
return fmt.Errorf("failed to flush backend message: %w", err)
}
case err := <-errors:
return err
}
Expand Down
8 changes: 6 additions & 2 deletions internal/pgproxy/pgproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import (
"net"
"testing"

"github.com/alecthomas/assert/v2"
"github.com/jackc/pgx/v5/pgproto3"

"github.com/TBD54566975/ftl/internal/dev"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/pgproxy"
"github.com/alecthomas/assert/v2"
"github.com/jackc/pgx/v5/pgproto3"
)

func TestPgProxy(t *testing.T) {
Expand Down Expand Up @@ -48,6 +49,9 @@ func TestPgProxy(t *testing.T) {
assert.NoError(t, frontend.Flush())

assertResponseType[*pgproto3.AuthenticationOk](t, frontend)
for range 13 {
assertResponseType[*pgproto3.ParameterStatus](t, frontend)
}
assertResponseType[*pgproto3.ReadyForQuery](t, frontend)
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
import io.quarkus.agroal.spi.JdbcDataSourceBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.GeneratedResourceBuildItem;
import io.quarkus.deployment.builditem.SystemPropertyBuildItem;
import xyz.block.ftl.runtime.FTLDatasourceCredentials;
import xyz.block.ftl.runtime.FTLRecorder;
import xyz.block.ftl.runtime.config.FTLConfigSource;
import xyz.block.ftl.v1.ModuleContextResponse;
import xyz.block.ftl.v1.schema.Database;
import xyz.block.ftl.v1.schema.Decl;

Expand All @@ -21,10 +25,12 @@ public class DatasourceProcessor {
private static final Logger log = Logger.getLogger(DatasourceProcessor.class);

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public SchemaContributorBuildItem registerDatasources(
List<JdbcDataSourceBuildItem> datasources,
BuildProducer<SystemPropertyBuildItem> systemPropProducer,
BuildProducer<GeneratedResourceBuildItem> generatedResourceBuildItemBuildProducer) {
BuildProducer<GeneratedResourceBuildItem> generatedResourceBuildItemBuildProducer,
FTLRecorder recorder) {
log.infof("Processing %d datasource annotations into decls", datasources.size());
List<Decl> decls = new ArrayList<>();
List<String> namedDatasources = new ArrayList<>();
Expand All @@ -37,6 +43,11 @@ public SchemaContributorBuildItem registerDatasources(
// FTL and quarkus use slightly different names
dbKind = "postgres";
}
if (dbKind.equals("mysql")) {
recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.MYSQL);
} else {
recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.POSTGRES);
}
//default name is <default> which is not a valid name
String sanitisedName = ds.getName().replace("<", "").replace(">", "");
//we use a dynamic credentials provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -36,6 +38,8 @@ public class FTLController implements LeaseClient {

private static volatile FTLController controller;

private final Map<String, ModuleContextResponse.DBType> databases = new ConcurrentHashMap<>();

/**
* TODO: look at how init should work, this is terrible and will break dev mode
*/
Expand Down Expand Up @@ -71,6 +75,10 @@ public static FTLController instance() {
verbService = VerbServiceGrpc.newStub(channel);
}

public void registerDatabase(String name, ModuleContextResponse.DBType type) {
databases.put(name, type);
}

public byte[] getSecret(String secretName) {
var context = getModuleContext();
if (context.containsSecrets(secretName)) {
Expand All @@ -88,6 +96,10 @@ public byte[] getConfig(String secretName) {
}

public Datasource getDatasource(String name) {
if (databases.get(name) == ModuleContextResponse.DBType.POSTGRES) {
var proxyAddress = System.getenv("FTL_PROXY_POSTGRES_ADDRESS");
return new Datasource("jdbc:postgresql://" + proxyAddress + "/" + name, "ftl", "ftl");
}
List<ModuleContextResponse.DSN> databasesList = getModuleContext().getDatabasesList();
for (var i : databasesList) {
if (i.getName().equals(name)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import xyz.block.ftl.runtime.http.FTLHttpHandler;
import xyz.block.ftl.runtime.http.HTTPVerbInvoker;
import xyz.block.ftl.v1.CallRequest;
import xyz.block.ftl.v1.ModuleContextResponse;

@Recorder
public class FTLRecorder {
Expand Down Expand Up @@ -171,4 +172,8 @@ public void run() {
}
});
}

public void registerDatabase(String dbKind, ModuleContextResponse.DBType name) {
FTLController.instance().registerDatabase(dbKind, name);
}
}

0 comments on commit d4720e4

Please sign in to comment.