Skip to content

Commit

Permalink
Simplify server write loop and session validation
Browse files Browse the repository at this point in the history
  • Loading branch information
maeb committed Oct 8, 2021
1 parent c337e4a commit 96595fe
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 89 deletions.
68 changes: 29 additions & 39 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@ package server

import (
"fmt"
"github.com/nlnwa/gowarc"
"github.com/nlnwa/veidemann-contentwriter/database"
"github.com/nlnwa/veidemann-contentwriter/settings"
"google.golang.org/grpc/codes"
"io"
"net"

"github.com/nlnwa/gowarc"
"github.com/nlnwa/veidemann-api/go/contentwriter/v1"
"github.com/nlnwa/veidemann-contentwriter/database"
"github.com/nlnwa/veidemann-contentwriter/settings"
"github.com/nlnwa/veidemann-contentwriter/telemetry"
otgrpc "github.com/opentracing-contrib/go-grpc"
"github.com/opentracing/opentracing-go"
"github.com/rs/zerolog/log"
"google.golang.org/grpc"
"net"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type GrpcServer struct {
Expand Down Expand Up @@ -95,73 +96,62 @@ type ContentWriterService struct {
recordOptions []gowarc.WarcRecordOption
}

func (s *ContentWriterService) Write(stream contentwriter.ContentWriter_WriteServer) error {
func (s *ContentWriterService) Write(stream contentwriter.ContentWriter_WriteServer) (err error) {
telemetry.ScopechecksTotal.Inc()
//telemetry.ScopecheckResponseTotal.With(prometheus.Labels{"code": strconv.Itoa(int(result.ExcludeReason))}).Inc()
ctx := newWriteSessionContext(s.configCache, s.recordOptions)
defer ctx.cancelSession()
defer func() {
if err != nil {
log.Error().Err(err).Str("code", status.Code(err).String()).Msg("")
}
}()

for {
request, err := stream.Recv()
if err == io.EOF {
return s.onCompleted(ctx, stream)
break
}
if err != nil {
log.Err(err).Msgf("Error caught: %s", err.Error())
ctx.cancelSession(err.Error())
return err
}

switch v := request.Value.(type) {
case *contentwriter.WriteRequest_Meta:
log.Trace().Msgf("Got API request %T for %d records", v, len(v.Meta.RecordMeta))
if err := ctx.setWriteRequestMeta(v.Meta); err != nil {
ctx.cancelSession(err.Error())
return err
}
ctx.setWriteRequestMeta(v.Meta)
case *contentwriter.WriteRequest_ProtocolHeader:
log.Trace().Msgf("Got API request %T for record #%d. Size: %d", v, v.ProtocolHeader.RecordNum, len(v.ProtocolHeader.GetData()))
if err := ctx.writeProtocolHeader(v.ProtocolHeader); err != nil {
return err
return status.Errorf(codes.Unknown, "failed to write protocol header: %v", err)
}
case *contentwriter.WriteRequest_Payload:
log.Trace().Msgf("Got API request %T for record #%d. Size: %d", v, v.Payload.RecordNum, len(v.Payload.GetData()))
if err := ctx.writePayoad(v.Payload); err != nil {
return err
if err := ctx.writePayload(v.Payload); err != nil {
return status.Errorf(codes.Unknown, "failed to write payload: %v", err)
}
case *contentwriter.WriteRequest_Cancel:
log.Trace().Msgf("Got API request %T", v)
ctx.cancelSession(v.Cancel)
log.Debug().Str("reason", v.Cancel).Msg("Write request cancelled")
return stream.SendAndClose(new(contentwriter.WriteReply))
default:
return fmt.Errorf("Invalid request %s", v)
return status.Errorf(codes.InvalidArgument, "invalid write request: %v", v)
}
}
}

func (s *ContentWriterService) onCompleted(context *writeSessionContext, stream contentwriter.ContentWriter_WriteServer) error {
if context.canceled {
return context.handleErr(codes.Canceled, "Session canceled")
//return stream.SendAndClose(&contentwriter.WriteReply{})
}

if context.meta == nil {
return context.handleErr(codes.InvalidArgument, "Missing metadata object")
}

if err := context.validateSession(); err != nil {
context.cancelSession("Validation failed: " + err.Error())
return err
if err := ctx.validateSession(); err != nil {
return status.Errorf(codes.Unknown, "validation failed: %v", err)
}

records := make([]gowarc.WarcRecord, len(context.records))
records := make([]gowarc.WarcRecord, len(ctx.records))
for i := 0; i < len(records); i++ {
records[i] = context.records[int32(i)]
records[i] = ctx.records[int32(i)]
}
writer := s.warcWriterRegistry.GetWarcWriter(context.collectionConfig, context.meta.RecordMeta[0])
writeResponseMeta, err := writer.Write(context.meta, records...)
writer := s.warcWriterRegistry.GetWarcWriter(ctx.collectionConfig, ctx.meta.RecordMeta[0])
writeReply, err := writer.Write(ctx.meta, records...)
if err != nil {
context.cancelSession("Failed writing record: " + err.Error())
return err
return status.Errorf(codes.Unknown, "failed writing record(s): %v", err)
}

return stream.SendAndClose(writeResponseMeta)
return stream.SendAndClose(writeReply)
}
80 changes: 30 additions & 50 deletions server/sessioncontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,20 @@ package server

import (
"context"
"errors"
"fmt"
"sync"

"github.com/nlnwa/gowarc"
"github.com/nlnwa/veidemann-api/go/config/v1"
"github.com/nlnwa/veidemann-api/go/contentwriter/v1"
"github.com/nlnwa/veidemann-contentwriter/database"
"github.com/nlnwa/veidemann-contentwriter/settings"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"sync"
)

type writeSessionContext struct {
log zerolog.Logger
settings settings.Settings
configCache database.ConfigCache
meta *contentwriter.WriteRequestMeta
collectionConfig *config.ConfigObject
Expand All @@ -42,7 +40,6 @@ type writeSessionContext struct {
recordBuilders map[int32]gowarc.WarcRecordBuilder
payloadStarted map[int32]bool
rbMapSync sync.Mutex
canceled bool
}

func newWriteSessionContext(configCache database.ConfigCache, recordOpts []gowarc.WarcRecordOption) *writeSessionContext {
Expand All @@ -56,64 +53,31 @@ func newWriteSessionContext(configCache database.ConfigCache, recordOpts []gowar
}
}

func (s *writeSessionContext) handleErr(code codes.Code, msg string, args ...interface{}) error {
m := fmt.Sprintf(msg, args...)
s.log.Error().Msg(m)
return status.Error(code, m)
}

func (s *writeSessionContext) setWriteRequestMeta(w *contentwriter.WriteRequestMeta) error {
if s.meta == nil {
s.log = log.With().Str("eid", w.ExecutionId).Str("uri", w.TargetUri).Logger()
}
func (s *writeSessionContext) setWriteRequestMeta(w *contentwriter.WriteRequestMeta) {
s.meta = w

if w.CollectionRef == nil {
return s.handleErr(codes.InvalidArgument, "No collection id in request")
}
if w.IpAddress == "" {
return s.handleErr(codes.InvalidArgument, "Missing IP-address")
}

collectionConfig, err := s.configCache.GetConfigObject(context.TODO(), w.GetCollectionRef())
if err != nil {
msg := "Error getting collection config " + w.GetCollectionRef().GetId()
s.log.Error().Msg(msg)
return status.Error(codes.Unknown, msg)
}
s.collectionConfig = collectionConfig
if collectionConfig == nil || collectionConfig.Meta == nil || collectionConfig.Spec == nil {
return s.handleErr(codes.Unknown, "Collection with id '%s' is missing or insufficient: %s", w.CollectionRef.Id, collectionConfig.String())
}
return nil
}

func (s *writeSessionContext) writeProtocolHeader(header *contentwriter.Data) error {
recordBuilder := s.getRecordBuilder(header.RecordNum)
if recordBuilder.Size() != 0 {
err := s.handleErr(codes.InvalidArgument, "Header received twice")
s.cancelSession(err.Error())
return err
return errors.New("protocol header received twice")
}
if _, err := recordBuilder.Write(header.GetData()); err != nil {
s.cancelSession(err.Error())
return err
return fmt.Errorf("failed to write protocol header to the record builder: %w", err)
}
return nil
}

func (s *writeSessionContext) writePayoad(payload *contentwriter.Data) error {
func (s *writeSessionContext) writePayload(payload *contentwriter.Data) error {
recordBuilder := s.getRecordBuilder(payload.RecordNum)
if !s.payloadStarted[payload.RecordNum] {
if _, err := recordBuilder.Write([]byte("\r\n")); err != nil {
s.cancelSession(err.Error())
return err
return fmt.Errorf("failed to write pre-payload whitespace to the record builder: %w", err)
}
s.payloadStarted[payload.RecordNum] = true
}
if _, err := recordBuilder.Write(payload.GetData()); err != nil {
s.cancelSession(err.Error())
return err
return fmt.Errorf("failed to write payload for record number %d to the record builder: %w", payload.RecordNum, err)
}
return nil
}
Expand All @@ -133,10 +97,28 @@ func (s *writeSessionContext) getRecordBuilder(recordNum int32) gowarc.WarcRecor
}

func (s *writeSessionContext) validateSession() error {
if s.meta == nil {
return errors.New("missing metadata object")
}
if s.meta.CollectionRef == nil {
return errors.New("missing collection ref")
}
if s.meta.IpAddress == "" {
return errors.New("missing IP-address")
}
collectionConfig, err := s.configCache.GetConfigObject(context.TODO(), s.meta.CollectionRef)
if err != nil {
return fmt.Errorf("failed to get collection config: %s", s.meta.GetCollectionRef().GetId())
}
if collectionConfig == nil || collectionConfig.Meta == nil || collectionConfig.Spec == nil {
return fmt.Errorf("collection with id '%s' is missing or insufficient: %s", s.meta.GetCollectionRef().Id, collectionConfig)
}
s.collectionConfig = collectionConfig

for k, rb := range s.recordBuilders {
recordMeta, ok := s.meta.RecordMeta[k]
if !ok {
return s.handleErr(codes.InvalidArgument, "Missing metadata for record num: %d", k)
return fmt.Errorf("missing metadata for record num: %d", k)
}

rt := ToGowarcRecordType(recordMeta.Type)
Expand All @@ -156,16 +138,14 @@ func (s *writeSessionContext) validateSession() error {

wr, _, err := rb.Build()
if err != nil {
return s.handleErr(codes.InvalidArgument, "Error: %s", err)
return fmt.Errorf("failed to build record number %d: %w", k, err)
}
s.records[k] = wr
}
return nil
}

func (s *writeSessionContext) cancelSession(cancelReason string) {
s.canceled = true
s.log.Debug().Msgf("Request cancelled before WARC record written. Reason %s", cancelReason)
func (s *writeSessionContext) cancelSession() {
for _, rb := range s.recordBuilders {
_ = rb.Close()
}
Expand Down

0 comments on commit 96595fe

Please sign in to comment.