Skip to content

Commit

Permalink
(task) refactor pool/driver struct (#94)
Browse files Browse the repository at this point in the history
* separate pool & driver setup
  • Loading branch information
eoinmcafee00 authored Apr 21, 2022
1 parent 709d595 commit f687dbf
Show file tree
Hide file tree
Showing 17 changed files with 232 additions and 437 deletions.
5 changes: 3 additions & 2 deletions command/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type (
MarketType string `json:"market_type,omitempty" yaml:"market_type"`
RootDirectory string `json:"root_directory,omitempty" yaml:"root_directory"`
Hibernate bool `json:"hibernate,omitempty"`
User string `json:"user,omitempty"`
}

VMFusion struct {
Expand Down Expand Up @@ -122,9 +123,9 @@ func (s *Instance) UnmarshalJSON(data []byte) error {
return err
}
switch s.Type {
case "amazon":
case "amazon", "aws":
s.Spec = new(Amazon)
case "gcp":
case "google", "gcp":
s.Spec = new(Google)
case "vmfusion":
s.Spec = new(VMFusion)
Expand Down
9 changes: 5 additions & 4 deletions engine/linter/linter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,11 @@ func Test_checkPools(t *testing.T) {
}

func DummyPool(name, runnerName string) drivers.Pool {
var pool, err = amazon.New(
amazon.WithRunnerName(runnerName),
amazon.WithName(name), // pool name
)
var pool drivers.Pool
pool.Name = name
pool.RunnerName = runnerName
var driver, err = amazon.New()
pool.Driver = driver
if err != nil {
return pool
}
Expand Down
9 changes: 5 additions & 4 deletions internal/certs/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ func Generate(runnerName string) (*types.InstanceCreateOpts, error) {
return nil, fmt.Errorf("failed to generate tls certificate: %w", err)
}
return &types.InstanceCreateOpts{
CACert: ca.Cert,
CAKey: ca.Key,
TLSCert: tlsCert.Cert,
TLSKey: tlsCert.Key,
CACert: ca.Cert,
CAKey: ca.Key,
TLSCert: tlsCert.Cert,
TLSKey: tlsCert.Key,
RunnerName: runnerName,
}, nil
}
66 changes: 25 additions & 41 deletions internal/drivers/amazon/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,13 @@ package amazon
import (
"os"

"github.com/drone-runners/drone-runner-aws/oshelp"

"github.com/sirupsen/logrus"
)

type Option func(*provider)

// WithRunnerName returns an option to set the runner name
func WithRunnerName(name string) Option {
return func(p *provider) {
p.runnerName = name
}
}

func WithOs(machineOs string) Option {
return func(p *provider) {
p.os = machineOs
}
}

func WithArch(arch string) Option {
return func(p *provider) {
p.arch = arch
}
}

func WithAccessKeyID(accessKeyID string) Option {
return func(p *provider) {
p.accessKeyID = accessKeyID
Expand Down Expand Up @@ -90,9 +73,17 @@ func WithSecurityGroup(group ...string) Option {
}

// WithSize returns an option to set the instance size.
func WithSize(size string) Option {
func WithSize(size, arch string) Option {
return func(p *provider) {
p.size = size
// set default instance type if not provided
if p.size == "" {
if arch == "arm64" {
p.size = "a1.medium"
} else {
p.size = "t3.nano"
}
}
}
}

Expand Down Expand Up @@ -163,27 +154,6 @@ func WithMarketType(t string) Option {
}
}

// WithName returns an option to set the instance name.
func WithName(name string) Option {
return func(p *provider) {
p.name = name
}
}

// WithLimit the total number of running servers. If exceeded block or error.
func WithLimit(limit int) Option {
return func(p *provider) {
p.limit = limit
}
}

// WithPool total number of warm instances in the pool at all times
func WithPool(pool int) Option {
return func(p *provider) {
p.pool = pool
}
}

// WithZone returns an option to set the zone.
func WithZone(zone string) Option {
return func(p *provider) {
Expand All @@ -203,3 +173,17 @@ func WithHibernate(hibernate bool) Option {
p.hibernate = hibernate
}
}

func WithUser(user, platform string) Option {
return func(p *provider) {
p.user = user
// set the default ssh user. this user account is responsible for executing the pipeline script.
if p.user == "" {
if platform == oshelp.OSWindows {
p.user = "Administrator"
} else {
p.user = "root"
}
}
}
}
48 changes: 16 additions & 32 deletions internal/drivers/amazon/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,24 @@ import (
"github.com/cenkalti/backoff/v4"
)

func (p *provider) GetProviderName() string {
func (p *provider) ProviderName() string {
return string(types.ProviderAmazon)
}

func (p *provider) GetName() string {
return p.name
}

func (p *provider) GetInstanceType() string {
func (p *provider) InstanceType() string {
return p.image
}

func (p *provider) GetOS() string {
return p.os
}

func (p *provider) GetRootDir() string {
func (p *provider) RootDir() string {
return p.rootDir
}

func (p *provider) GetMaxSize() int {
return p.limit
}

func (p *provider) GetMinSize() int {
return p.pool
}

func (p *provider) CanHibernate() bool {
return p.hibernate
}

// PingProvider checks that we can log into EC2, and the regions respond
func (p *provider) PingProvider(ctx context.Context) error {
// Ping checks that we can log into EC2, and the regions respond
func (p *provider) Ping(ctx context.Context) error {
client := p.service

allRegions := true
Expand All @@ -67,12 +51,12 @@ func (p *provider) Create(ctx context.Context, opts *types.InstanceCreateOpts) (

logr := logger.FromContext(ctx).
WithField("provider", types.ProviderAmazon).
WithField("ami", p.GetInstanceType()).
WithField("pool", p.name).
WithField("ami", p.InstanceType()).
WithField("pool", opts.PoolName).
WithField("region", p.region).
WithField("image", p.image).
WithField("size", p.size)
var name = fmt.Sprintf(p.runnerName+"-"+p.name+"-%d", time.Now().Unix())
var name = fmt.Sprintf(opts.RunnerName+"-"+opts.PoolName+"-%d", time.Now().Unix())

var tags = map[string]string{
"Name": name,
Expand All @@ -98,7 +82,7 @@ func (p *provider) Create(ctx context.Context, opts *types.InstanceCreateOpts) (
IamInstanceProfile: iamProfile,
UserData: aws.String(
base64.StdEncoding.EncodeToString(
[]byte(userdata.Generate(p.userData, p.os, p.arch, opts)),
[]byte(userdata.Generate(p.userData, opts)),
),
),
NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
Expand Down Expand Up @@ -182,13 +166,13 @@ func (p *provider) Create(ctx context.Context, opts *types.InstanceCreateOpts) (
Name: instanceID,
Provider: types.ProviderAmazon,
State: types.StateCreated,
Pool: p.name,
Pool: opts.PoolName,
Image: p.image,
Zone: p.availabilityZone,
Region: p.region,
Size: p.size,
Platform: p.os,
Arch: p.arch,
Platform: opts.OS,
Arch: opts.Arch,
Address: instanceIP,
CACert: opts.CACert,
CAKey: opts.CAKey,
Expand Down Expand Up @@ -234,10 +218,10 @@ func (p *provider) Destroy(ctx context.Context, instanceIDs ...string) (err erro
return
}

func (p *provider) Hibernate(ctx context.Context, instanceID string) error {
func (p *provider) Hibernate(ctx context.Context, instanceID, poolName string) error {
logr := logger.FromContext(ctx).
WithField("provider", types.ProviderAmazon).
WithField("pool", p.name).
WithField("pool", poolName).
WithField("instanceID", instanceID)

client := p.service
Expand All @@ -255,12 +239,12 @@ func (p *provider) Hibernate(ctx context.Context, instanceID string) error {
return nil
}

func (p *provider) Start(ctx context.Context, instanceID string) (string, error) {
func (p *provider) Start(ctx context.Context, instanceID, poolName string) (string, error) {
client := p.service

logr := logger.FromContext(ctx).
WithField("provider", types.ProviderAmazon).
WithField("pool", p.name).
WithField("pool", poolName).
WithField("instanceID", instanceID)

amazonInstance, err := p.getInstance(ctx, instanceID)
Expand Down
46 changes: 2 additions & 44 deletions internal/drivers/amazon/provider.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
package amazon

import (
"github.com/drone-runners/drone-runner-aws/internal/drivers"
"github.com/drone-runners/drone-runner-aws/oshelp"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/drone-runners/drone-runner-aws/internal/drivers"
)

// provider is a struct that implements drivers.Pool interface
type provider struct {
name string
runnerName string

spotInstance bool
region string
availabilityZone string
Expand All @@ -24,8 +19,6 @@ type provider struct {
secretAccessKey string
keyPairName string

os string
arch string
rootDir string

image string
Expand All @@ -43,50 +36,23 @@ type provider struct {
iamProfileArn string
hibernate bool

// pool size data
pool int
limit int

service *ec2.EC2
}

func New(opts ...Option) (drivers.Pool, error) {
func New(opts ...Option) (drivers.Driver, error) {
p := new(provider)
for _, opt := range opts {
opt(p)
}
if p.retries == 0 {
p.retries = 10
}
if p.pool < 0 {
p.pool = 0
}
if p.limit <= 0 {
p.limit = 100
}
if p.pool > p.limit {
p.limit = p.pool
}
if p.region == "" {
p.region = "us-east-1"
if p.availabilityZone == "" {
p.availabilityZone = "us-east-1a"
}
}
if p.os == "" {
p.os = oshelp.OSLinux
}
if p.arch == "" {
p.arch = "amd64"
}
// set default instance type if not provided
if p.size == "" {
if p.arch == "arm64" {
p.size = "a1.medium"
} else {
p.size = "t3.nano"
}
}
// set the default disk size if not provided
if p.volumeSize == 0 {
p.volumeSize = 32
Expand All @@ -103,14 +69,6 @@ func New(opts ...Option) (drivers.Pool, error) {
if p.deviceName == "" {
p.deviceName = "/dev/sda1"
}
// set the default ssh user. this user account is responsible for executing the pipeline script.
if p.user == "" {
if p.os == oshelp.OSWindows {
p.user = "Administrator"
} else {
p.user = "root"
}
}
// setup service if not provided
if p.service == nil {
config := &aws.Config{
Expand Down
Loading

0 comments on commit f687dbf

Please sign in to comment.