Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(task) refactor pool/driver struct #94

Merged
merged 2 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions command/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type (
MarketType string `json:"market_type,omitempty" yaml:"market_type"`
RootDirectory string `json:"root_directory,omitempty" yaml:"root_directory"`
Hibernate bool `json:"hibernate,omitempty"`
User string `json:"user,omitempty"`
}

VMFusion struct {
Expand Down Expand Up @@ -122,9 +123,9 @@ func (s *Instance) UnmarshalJSON(data []byte) error {
return err
}
switch s.Type {
case "amazon":
case "amazon", "aws":
s.Spec = new(Amazon)
case "gcp":
case "google", "gcp":
s.Spec = new(Google)
case "vmfusion":
s.Spec = new(VMFusion)
Expand Down
9 changes: 5 additions & 4 deletions engine/linter/linter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,11 @@ func Test_checkPools(t *testing.T) {
}

func DummyPool(name, runnerName string) drivers.Pool {
var pool, err = amazon.New(
amazon.WithRunnerName(runnerName),
amazon.WithName(name), // pool name
)
var pool drivers.Pool
pool.Name = name
pool.RunnerName = runnerName
var driver, err = amazon.New()
pool.Driver = driver
if err != nil {
return pool
}
Expand Down
9 changes: 5 additions & 4 deletions internal/certs/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ func Generate(runnerName string) (*types.InstanceCreateOpts, error) {
return nil, fmt.Errorf("failed to generate tls certificate: %w", err)
}
return &types.InstanceCreateOpts{
CACert: ca.Cert,
CAKey: ca.Key,
TLSCert: tlsCert.Cert,
TLSKey: tlsCert.Key,
CACert: ca.Cert,
CAKey: ca.Key,
TLSCert: tlsCert.Cert,
TLSKey: tlsCert.Key,
RunnerName: runnerName,
}, nil
}
66 changes: 25 additions & 41 deletions internal/drivers/amazon/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,13 @@ package amazon
import (
"os"

"github.com/drone-runners/drone-runner-aws/oshelp"

"github.com/sirupsen/logrus"
)

type Option func(*provider)

// WithRunnerName returns an option to set the runner name
func WithRunnerName(name string) Option {
return func(p *provider) {
p.runnerName = name
}
}

func WithOs(machineOs string) Option {
return func(p *provider) {
p.os = machineOs
}
}

func WithArch(arch string) Option {
return func(p *provider) {
p.arch = arch
}
}

func WithAccessKeyID(accessKeyID string) Option {
return func(p *provider) {
p.accessKeyID = accessKeyID
Expand Down Expand Up @@ -90,9 +73,17 @@ func WithSecurityGroup(group ...string) Option {
}

// WithSize returns an option to set the instance size.
func WithSize(size string) Option {
func WithSize(size, arch string) Option {
return func(p *provider) {
p.size = size
// set default instance type if not provided
if p.size == "" {
if arch == "arm64" {
p.size = "a1.medium"
} else {
p.size = "t3.nano"
}
}
}
}

Expand Down Expand Up @@ -163,27 +154,6 @@ func WithMarketType(t string) Option {
}
}

// WithName returns an option to set the instance name.
func WithName(name string) Option {
return func(p *provider) {
p.name = name
}
}

// WithLimit the total number of running servers. If exceeded block or error.
func WithLimit(limit int) Option {
return func(p *provider) {
p.limit = limit
}
}

// WithPool total number of warm instances in the pool at all times
func WithPool(pool int) Option {
return func(p *provider) {
p.pool = pool
}
}

// WithZone returns an option to set the zone.
func WithZone(zone string) Option {
return func(p *provider) {
Expand All @@ -203,3 +173,17 @@ func WithHibernate(hibernate bool) Option {
p.hibernate = hibernate
}
}

func WithUser(user, platform string) Option {
return func(p *provider) {
p.user = user
// set the default ssh user. this user account is responsible for executing the pipeline script.
if p.user == "" {
if platform == oshelp.OSWindows {
p.user = "Administrator"
} else {
p.user = "root"
}
}
}
}
48 changes: 16 additions & 32 deletions internal/drivers/amazon/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,24 @@ import (
"github.com/cenkalti/backoff/v4"
)

func (p *provider) GetProviderName() string {
func (p *provider) ProviderName() string {
return string(types.ProviderAmazon)
}

func (p *provider) GetName() string {
return p.name
}

func (p *provider) GetInstanceType() string {
func (p *provider) InstanceType() string {
return p.image
}

func (p *provider) GetOS() string {
return p.os
}

func (p *provider) GetRootDir() string {
func (p *provider) RootDir() string {
return p.rootDir
}

func (p *provider) GetMaxSize() int {
return p.limit
}

func (p *provider) GetMinSize() int {
return p.pool
}

func (p *provider) CanHibernate() bool {
return p.hibernate
}

// PingProvider checks that we can log into EC2, and the regions respond
func (p *provider) PingProvider(ctx context.Context) error {
// Ping checks that we can log into EC2, and the regions respond
func (p *provider) Ping(ctx context.Context) error {
client := p.service

allRegions := true
Expand All @@ -67,12 +51,12 @@ func (p *provider) Create(ctx context.Context, opts *types.InstanceCreateOpts) (

logr := logger.FromContext(ctx).
WithField("provider", types.ProviderAmazon).
WithField("ami", p.GetInstanceType()).
WithField("pool", p.name).
WithField("ami", p.InstanceType()).
WithField("pool", opts.PoolName).
WithField("region", p.region).
WithField("image", p.image).
WithField("size", p.size)
var name = fmt.Sprintf(p.runnerName+"-"+p.name+"-%d", time.Now().Unix())
var name = fmt.Sprintf(opts.RunnerName+"-"+opts.PoolName+"-%d", time.Now().Unix())

var tags = map[string]string{
"Name": name,
Expand All @@ -98,7 +82,7 @@ func (p *provider) Create(ctx context.Context, opts *types.InstanceCreateOpts) (
IamInstanceProfile: iamProfile,
UserData: aws.String(
base64.StdEncoding.EncodeToString(
[]byte(userdata.Generate(p.userData, p.os, p.arch, opts)),
[]byte(userdata.Generate(p.userData, opts)),
),
),
NetworkInterfaces: []*ec2.InstanceNetworkInterfaceSpecification{
Expand Down Expand Up @@ -182,13 +166,13 @@ func (p *provider) Create(ctx context.Context, opts *types.InstanceCreateOpts) (
Name: instanceID,
Provider: types.ProviderAmazon,
State: types.StateCreated,
Pool: p.name,
Pool: opts.PoolName,
Image: p.image,
Zone: p.availabilityZone,
Region: p.region,
Size: p.size,
Platform: p.os,
Arch: p.arch,
Platform: opts.OS,
Arch: opts.Arch,
Address: instanceIP,
CACert: opts.CACert,
CAKey: opts.CAKey,
Expand Down Expand Up @@ -234,10 +218,10 @@ func (p *provider) Destroy(ctx context.Context, instanceIDs ...string) (err erro
return
}

func (p *provider) Hibernate(ctx context.Context, instanceID string) error {
func (p *provider) Hibernate(ctx context.Context, instanceID, poolName string) error {
logr := logger.FromContext(ctx).
WithField("provider", types.ProviderAmazon).
WithField("pool", p.name).
WithField("pool", poolName).
WithField("instanceID", instanceID)

client := p.service
Expand All @@ -255,12 +239,12 @@ func (p *provider) Hibernate(ctx context.Context, instanceID string) error {
return nil
}

func (p *provider) Start(ctx context.Context, instanceID string) (string, error) {
func (p *provider) Start(ctx context.Context, instanceID, poolName string) (string, error) {
client := p.service

logr := logger.FromContext(ctx).
WithField("provider", types.ProviderAmazon).
WithField("pool", p.name).
WithField("pool", poolName).
WithField("instanceID", instanceID)

amazonInstance, err := p.getInstance(ctx, instanceID)
Expand Down
46 changes: 2 additions & 44 deletions internal/drivers/amazon/provider.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
package amazon

import (
"github.com/drone-runners/drone-runner-aws/internal/drivers"
"github.com/drone-runners/drone-runner-aws/oshelp"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/drone-runners/drone-runner-aws/internal/drivers"
)

// provider is a struct that implements drivers.Pool interface
type provider struct {
name string
runnerName string

spotInstance bool
region string
availabilityZone string
Expand All @@ -24,8 +19,6 @@ type provider struct {
secretAccessKey string
keyPairName string

os string
arch string
rootDir string

image string
Expand All @@ -43,50 +36,23 @@ type provider struct {
iamProfileArn string
hibernate bool

// pool size data
pool int
limit int

service *ec2.EC2
}

func New(opts ...Option) (drivers.Pool, error) {
func New(opts ...Option) (drivers.Driver, error) {
p := new(provider)
for _, opt := range opts {
opt(p)
}
if p.retries == 0 {
p.retries = 10
}
if p.pool < 0 {
p.pool = 0
}
if p.limit <= 0 {
p.limit = 100
}
if p.pool > p.limit {
p.limit = p.pool
}
if p.region == "" {
p.region = "us-east-1"
if p.availabilityZone == "" {
p.availabilityZone = "us-east-1a"
}
}
if p.os == "" {
p.os = oshelp.OSLinux
}
if p.arch == "" {
p.arch = "amd64"
}
// set default instance type if not provided
if p.size == "" {
if p.arch == "arm64" {
p.size = "a1.medium"
} else {
p.size = "t3.nano"
}
}
// set the default disk size if not provided
if p.volumeSize == 0 {
p.volumeSize = 32
Expand All @@ -103,14 +69,6 @@ func New(opts ...Option) (drivers.Pool, error) {
if p.deviceName == "" {
p.deviceName = "/dev/sda1"
}
// set the default ssh user. this user account is responsible for executing the pipeline script.
if p.user == "" {
if p.os == oshelp.OSWindows {
p.user = "Administrator"
} else {
p.user = "root"
}
}
// setup service if not provided
if p.service == nil {
config := &aws.Config{
Expand Down
Loading