From 4f7fce5852861cc4f39f55b132cbadaf2bea9146 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Fri, 28 Jun 2024 11:29:23 +0200 Subject: [PATCH 1/6] feat(swupdate): provide a subcommand to reboot the device --- .vscode/settings.json | 3 +- cmd/ovp8xx/cmd/swupdate.go | 55 +++++++++++++++++++++++++- pkg/swupdater/swupdater.go | 80 +++++++++++++++++++++++++++++++------- 3 files changed, 121 insertions(+), 17 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 7a53f8d..5645ae6 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { "cSpell.words": [ - "PCIC" + "PCIC", + "swupdater" ] } \ No newline at end of file diff --git a/cmd/ovp8xx/cmd/swupdate.go b/cmd/ovp8xx/cmd/swupdate.go index 0f1548f..cce8247 100644 --- a/cmd/ovp8xx/cmd/swupdate.go +++ b/cmd/ovp8xx/cmd/swupdate.go @@ -94,9 +94,60 @@ The command establishes a connection to the device, uploads the firmware file, a RunE: swupdateCommand, } +// restartCmd represents the restart command. +// It restarts the device using the SWUpdater service. +// Depending on the state of the device, this command will either reboot to productive mode +// or restart the SWUpdater service again. +// If a previous update was initiated but not successful, the device will restart the SWUpdater service again. +var restartCmd = &cobra.Command{ + Use: "restart", + Short: "Restart the device", + Long: `This command restarts the device using the SWUpdater service. +Depending on the state of the device this will reboot to productive mode +or restart the SWUpdater service again. + +In case a previous update was initiated but not successful the device will restart +the SWUpdater service again. +`, + RunE: func(cmd *cobra.Command, args []string) error { + // Retrieve host and port from the parent command's flags + host, err := rootCmd.PersistentFlags().GetString("ip") + if err != nil { + return fmt.Errorf("cannot get host: %w", err) + } + + port, err := cmd.Parent().Flags().GetUint16("port") + if err != nil { + // If the port is not set on the parent, use a default value or handle the error + return fmt.Errorf("cannot get port: %w", err) + } + + connectionTimeout, err := cmd.Flags().GetDuration("online") + if err != nil { + return fmt.Errorf("cannot get timeout: %w", err) + } + + updater := swupdater.NewSWUpdater(host, port, nil) + + // Call the Restart method on the SWUpdater instance + if err := updater.Restart(connectionTimeout); err != nil { + return fmt.Errorf("failed to restart the device: %w", err) + } + + fmt.Println("Device restart initiated successfully.") + return nil + }, +} + func init() { rootCmd.AddCommand(swupdateCmd) - swupdateCmd.Flags().Uint16("port", 8080, "Port number for SWUpdate") - swupdateCmd.Flags().Duration("timeout", 5*time.Minute, "The timeout for the upload") + + swupdateCmd.PersistentFlags().Uint16("port", 8080, "Port number for SWUpdate") + swupdateCmd.PersistentFlags().Duration("online", 2*time.Minute, "The time to wait for the device to become available") swupdateCmd.Flags().Duration("online", 2*time.Minute, "The time to wait for the device to become available") + swupdateCmd.Flags().Duration("timeout", 5*time.Minute, "The timeout for the upload") + + // The restart sub command + swupdateCmd.AddCommand(restartCmd) + restartCmd.Flags().Duration("online", 3*time.Second, "The time to wait for the device to become available") } diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go index 6538533..7ce1d56 100644 --- a/pkg/swupdater/swupdater.go +++ b/pkg/swupdater/swupdater.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "os" + "path/filepath" "strings" "time" @@ -40,7 +41,10 @@ func NewSWUpdater(hostName string, port uint16, notifications chan SWUpdaterNoti // The filename parameter specifies the name of the file to be uploaded. // Returns an error if the upload fails. func (s *SWUpdater) upload(filename string) error { - s.statusUpdate(fmt.Sprintf("Uploading software image to %s\n", s.urlUpload)) + basename := filepath.Base(filename) + s.statusUpdate( + fmt.Sprintf("Uploading software image: %s to %s\n", basename, s.urlUpload), + ) const fieldname string = "file" file, err := os.Open(filename) @@ -155,25 +159,17 @@ func (s *SWUpdater) statusUpdate(status string) { // It returns an error if the upload fails, or if the operation times out. func (s *SWUpdater) Update(filename string, connectionTimeout, timeout time.Duration) error { done := make(chan error) - start := time.Now() - s.statusUpdate("Waiting for the Device to become ready...") // Retry connection until successful or connectionTimeout occurs - for { - err := s.connect() - if err == nil { - s.statusUpdate("Device is ready now") - break - } - if time.Since(start) > connectionTimeout { - return fmt.Errorf("connection timeout: %w", err) - } - time.Sleep(3 * time.Second) // wait for a second before retrying + online, err := s.waitForOnline(connectionTimeout) + if !online { + return err } + // close the websocket after the Update operation defer s.disconnect() s.statusUpdate("Starting the Software Update process...") go s.waitForFinished(done) - err := s.upload(filename) + err = s.upload(filename) if err != nil { return fmt.Errorf("cannot upload software image: %w", err) } @@ -189,3 +185,59 @@ func (s *SWUpdater) Update(filename string, connectionTimeout, timeout time.Dura return errors.New("a timeout occurred while waiting for the update to finish") } } + +func (s *SWUpdater) waitForOnline(connectionTimeout time.Duration) (bool, error) { + start := time.Now() + s.statusUpdate("Waiting for the Device to become ready...") + + for { + err := s.connect() + if err == nil { + s.statusUpdate("Device is ready now") + break + } + if time.Since(start) > connectionTimeout { + return false, fmt.Errorf("connection timeout: %w", err) + } + // Retry after 3 seconds + time.Sleep(3 * time.Second) + } + return true, nil +} + +// Restart reboots the device by sending a POST request to the restart endpoint. +func (s *SWUpdater) Restart(timeout time.Duration) error { + // Construct the URL for the restart endpoint + restartURL := fmt.Sprintf("http://%s:%d/restart", s.hostName, s.port) + + online, err := s.waitForOnline(timeout) + if !online { + return err + } + // close the websocket after the Restart operation + defer s.disconnect() + + // Create a POST request with an empty body + req, err := http.NewRequest("POST", restartURL, nil) + if err != nil { + return fmt.Errorf("failed to create request (%s): %w", restartURL, err) + } + + // Send the request + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send restart request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusServiceUnavailable { + return fmt.Errorf("the SWUpdate service is not available at the moment, please try again later") + } + + // Check if the response status code indicates success + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("restart request (%s) failed with status code: %d", restartURL, resp.StatusCode) + } + + return nil +} From 7d461ae5514a08f638bfdc1bd88fbd079e789997 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Fri, 5 Jul 2024 16:11:09 +0200 Subject: [PATCH 2/6] feat: add PCIC commands --- cmd/ovp8xx/cmd/pcic.go | 40 ++++++++++++++--- pkg/pcic/protocol.go | 95 ++++++++++++++++++++++++++++++++++++++- pkg/pcic/protocol_test.go | 5 +++ 3 files changed, 131 insertions(+), 9 deletions(-) diff --git a/cmd/ovp8xx/cmd/pcic.go b/cmd/ovp8xx/cmd/pcic.go index 5c0b6d7..6c756ad 100644 --- a/cmd/ovp8xx/cmd/pcic.go +++ b/cmd/ovp8xx/cmd/pcic.go @@ -5,6 +5,7 @@ package cmd import ( "fmt" + "sync" "github.com/graugans/go-ovp8xx/v2/pkg/pcic" "github.com/spf13/cobra" @@ -40,6 +41,10 @@ func (r *PCICReceiver) Notification(msg pcic.NotificationMessage) { fmt.Printf("Notification: %v\n", msg) } +func (r *PCICReceiver) CommandResponse(rsp pcic.Response) { + fmt.Printf("Command Response Ticket: %v Data: %s\n", rsp.Ticket, string(rsp.Data)) +} + // pcicCommand is a function that handles the execution of the "pcic" command. // It initializes a PCICReceiver, creates a helper, and establishes a connection to the PCIC client. // It then continuously processes incoming data using the PCIC client and the testHandler. @@ -48,6 +53,14 @@ func (r *PCICReceiver) Notification(msg pcic.NotificationMessage) { func pcicCommand(cmd *cobra.Command, args []string) error { var testHandler *PCICReceiver = &PCICReceiver{} var err error + + // Retrieve the slice of commands + cmds, err := cmd.Flags().GetStringSlice("cmd") + if err != nil { + // Handle the error + return err + } + helper, err := NewHelper(cmd) if err != nil { return err @@ -56,16 +69,28 @@ func pcicCommand(cmd *cobra.Command, args []string) error { pcic, err := pcic.NewPCICClient( pcic.WithTCPClient(helper.hostname(), helper.remotePort()), ) - if err != nil { - return err - } - for { - err = pcic.ProcessIncomming(testHandler) + var wg sync.WaitGroup + wg.Add(1) // We're going to wait for one goroutine + + go func() { + defer wg.Done() // This will be called when the goroutine finishes + for { + err = pcic.ProcessIncomming(testHandler) + if err != nil { + // An error occured, we break the loop + break + } + } + }() + // execute the commands + for _, cmd := range cmds { + _, err = pcic.Send([]byte(cmd)) if err != nil { - // An error occured, we break the loop - break + return fmt.Errorf("failed to send command: %w", err) } } + // Wait for the goroutine to finish before executing the commands + wg.Wait() return err } @@ -79,4 +104,5 @@ var pcicCmd = &cobra.Command{ func init() { rootCmd.AddCommand(pcicCmd) pcicCmd.Flags().Uint16("port", 50010, "The port to connect to") + pcicCmd.Flags().StringSlice("cmd", []string{}, "Commands to send to the device") } diff --git a/pkg/pcic/protocol.go b/pkg/pcic/protocol.go index 7ef1038..b96e1cd 100644 --- a/pkg/pcic/protocol.go +++ b/pkg/pcic/protocol.go @@ -6,7 +6,10 @@ import ( "errors" "fmt" "io" + "math/rand" "net" + "strconv" + "strings" ) type ( @@ -47,6 +50,7 @@ type MessageHandler interface { Result(Frame) Error(ErrorMessage) Notification(NotificationMessage) + CommandResponse(Response) } type NotificationMessage struct { @@ -59,6 +63,11 @@ type ErrorMessage struct { Message string } +type Response struct { + Ticket string + Data []byte +} + func NewPCICClient(options ...PCICClientOption) (*PCICClient, error) { var err error pcic := &PCICClient{} @@ -105,10 +114,12 @@ func (p *PCICClient) ProcessIncomming(handler MessageHandler) error { return err } firstTicket := header[:ticketFieldLength] + ticketStr := string(firstTicket) + secondTicket := header[secondTicketOffset:dataOffset] if !bytes.Equal(firstTicket, secondTicket) { return fmt.Errorf("mismatch in the tickets %s != %s ", - string(firstTicket), + string(ticketStr), string(secondTicket), ) } @@ -132,7 +143,21 @@ func (p *PCICClient) ProcessIncomming(handler MessageHandler) error { if !bytes.Equal(trailer, []byte{'\r', '\n'}) { return errors.New("invalid trailer detected") } - if bytes.Equal(resultTicket, firstTicket) { + var ticketNum = 0 + if ticketStr != "0000" { + ticketNum, err = strconv.Atoi(strings.TrimLeft(ticketStr, "0")) + if err != nil { + return fmt.Errorf("unable to convert the ticket number %s to an integer", ticketStr) + } + } + if ticketNum > 100 { + r, err := responseParser(ticketStr, data) + if err != nil { + return fmt.Errorf("unable to parse the response: %w", err) + } + handler.CommandResponse(r) + return nil + } else if bytes.Equal(resultTicket, firstTicket) { frame, err := asyncResultParser(data) handler.Result(frame) return err @@ -144,6 +169,66 @@ func (p *PCICClient) ProcessIncomming(handler MessageHandler) error { return fmt.Errorf("unknown ticket received: %s", string(firstTicket)) } +func (p *PCICClient) Send(data []byte) (uint16, error) { + var ticket uint16 + if p.writer == nil { + return ticket, errors.New("no bufio.Writer provided, please instantiate the object") + } + // Let's generate a random ticket number + ticket = uint16(rand.Intn(8999) + 1000) + if ticket < 100 || ticket > 9999 { + return ticket, fmt.Errorf( + "invalid ticket number: %d, needs to be in the range 100-9999", ticket, + ) + } + + // Create a new buffer to aggregate the message + var buf bytes.Buffer + var delimter = []byte("\r\n") + // Convert ticket to a 4-digit string and then to bytes + ticketBytes := []byte(fmt.Sprintf("%04d", ticket)) + buf.Write(ticketBytes) + // A Command message is composed like this + // CRLFCRLF + // is a 4-digit number in the range 100-9999 + // is a character string starting with an 'L' followed by 9 digits + // interpreted as a decimal value. The number is the length of data that follows + // is the actual data that is being sent + length := len(ticketBytes) /**/ + len(data) + 2 /*CRLF*/ + lengthStr := fmt.Sprintf("L%09d", length) + lengthBytes := []byte(lengthStr) + buf.Write(lengthBytes) + buf.Write(delimter) + buf.Write(ticketBytes) + buf.Write(data) + buf.Write(delimter) + + // Write the buffer to the underlying writer + _, err := p.writer.Write(buf.Bytes()) + if err != nil { + return ticket, fmt.Errorf("unable to write to the buffer: %w", err) + } + // This is necessary to flush the buffer to the underlying writer + // Otherwise, the data will not be sent over the network + err = p.writer.Flush() + if err != nil { + return ticket, fmt.Errorf("unable to flush to the buffer: %w", err) + } + + return ticket, nil +} + +func responseParser(ticket string, data []byte) (Response, error) { + var err error + res := Response{} + if len(data) <= delimiterFieldLength { + return res, fmt.Errorf("the data is too short to be a valid frame: %d", len(data)) + } + res.Ticket = ticket + res.Data = data[:len(data)-delimiterFieldLength] + return res, err +} + func errorParser(data []byte) (ErrorMessage, error) { var err error errorStatus := ErrorMessage{} @@ -162,7 +247,13 @@ func errorParser(data []byte) (ErrorMessage, error) { func asyncResultParser(data []byte) (Frame, error) { frame := Frame{} var err error + if len(data) <= delimiterFieldLength { + return frame, fmt.Errorf("the data is too short to be a valid frame: %d", len(data)) + } contentDecorated := data[:len(data)-delimiterFieldLength] + if len(contentDecorated)-len(endMarker) < 0 { + return frame, fmt.Errorf("the data is too short to be a valid frame: %d: content: %s", len(data), string(data)) + } content := contentDecorated[len(endMarker) : len(contentDecorated)-len(endMarker)] if len(content) == 0 { // no content is available diff --git a/pkg/pcic/protocol_test.go b/pkg/pcic/protocol_test.go index c3d3ef6..69dbd7e 100644 --- a/pkg/pcic/protocol_test.go +++ b/pkg/pcic/protocol_test.go @@ -23,6 +23,7 @@ type PCICAsyncReceiver struct { frame pcic.Frame notificationMsg pcic.NotificationMessage errorMsg pcic.ErrorMessage + response pcic.Response } func (r *PCICAsyncReceiver) Result(frame pcic.Frame) { @@ -37,6 +38,10 @@ func (r *PCICAsyncReceiver) Notification(msg pcic.NotificationMessage) { r.notificationMsg = msg } +func (r *PCICAsyncReceiver) CommandResponse(res pcic.Response) { + r.response = res +} + var testHandler *PCICAsyncReceiver = &PCICAsyncReceiver{} func TestMinimalReceive(t *testing.T) { From 1869b263f5606ba2a705f99dc52ad398abcc5394 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Fri, 5 Jul 2024 14:15:55 +0000 Subject: [PATCH 3/6] build: use the latest Go version in the Devcontainer --- .devcontainer/devcontainer.json | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 53d4242..3fea99a 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -3,7 +3,7 @@ { "name": "Go", // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile - "image": "ghcr.io/graugans/golang:latest" + "image": "mcr.microsoft.com/devcontainers/go:1.22", // Features to add to the dev container. More info: https://containers.dev/features. // "features": {}, @@ -23,9 +23,7 @@ "golang.go" ] } - }, - - + } // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. // "remoteUser": "root" } From 823990febfd30e2c9545cec22ff450ededd556f4 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Sat, 6 Jul 2024 10:35:49 +0000 Subject: [PATCH 4/6] fix: synchronize the command with the response --- cmd/ovp8xx/cmd/pcic.go | 39 ++++++++++++++++---- pkg/pcic/protocol.go | 78 +++++++++++++++++++++++++++++---------- pkg/pcic/protocol_test.go | 5 --- 3 files changed, 89 insertions(+), 33 deletions(-) diff --git a/cmd/ovp8xx/cmd/pcic.go b/cmd/ovp8xx/cmd/pcic.go index 6c756ad..db98fd5 100644 --- a/cmd/ovp8xx/cmd/pcic.go +++ b/cmd/ovp8xx/cmd/pcic.go @@ -4,8 +4,11 @@ Copyright © 2024 Christian Ege package cmd import ( + "context" "fmt" + "strconv" "sync" + "time" "github.com/graugans/go-ovp8xx/v2/pkg/pcic" "github.com/spf13/cobra" @@ -41,10 +44,6 @@ func (r *PCICReceiver) Notification(msg pcic.NotificationMessage) { fmt.Printf("Notification: %v\n", msg) } -func (r *PCICReceiver) CommandResponse(rsp pcic.Response) { - fmt.Printf("Command Response Ticket: %v Data: %s\n", rsp.Ticket, string(rsp.Data)) -} - // pcicCommand is a function that handles the execution of the "pcic" command. // It initializes a PCICReceiver, creates a helper, and establishes a connection to the PCIC client. // It then continuously processes incoming data using the PCIC client and the testHandler. @@ -82,14 +81,38 @@ func pcicCommand(cmd *cobra.Command, args []string) error { } } }() + // execute the commands for _, cmd := range cmds { - _, err = pcic.Send([]byte(cmd)) + prefix := fmt.Sprintf(" %s # ", cmd) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + response, err := pcic.Send(ctx, []byte(cmd)) if err != nil { - return fmt.Errorf("failed to send command: %w", err) + cancel() + return fmt.Errorf("failed to send command: %v", err) + + } + if len(response) >= 9 { // Ensure there are at least 9 bytes + lengthStr := string(response[:9]) // Convert the first 9 bytes to a string + length, err := strconv.Atoi(lengthStr) // Convert the string to an integer + if err != nil { + // Response does not start with the length, print the whole response + fmt.Println(prefix, (response)) + } else { + if len(response) >= 9+length { + // Strip the first 9 bytes and print the rest up to the specified length + fmt.Println(string(response[9 : 9+length])) + } else { + cancel() + return fmt.Errorf("response too short: %s", string(response)) + } + } + } else { + fmt.Println(prefix, string(response)) } + cancel() } - // Wait for the goroutine to finish before executing the commands + // Wait for the goroutine to be finished wg.Wait() return err } @@ -104,5 +127,5 @@ var pcicCmd = &cobra.Command{ func init() { rootCmd.AddCommand(pcicCmd) pcicCmd.Flags().Uint16("port", 50010, "The port to connect to") - pcicCmd.Flags().StringSlice("cmd", []string{}, "Commands to send to the device") + pcicCmd.Flags().StringSlice("cmd", []string{}, "Commands to be send to the device, can be specified multiple times. All commands will be executed in order") } diff --git a/pkg/pcic/protocol.go b/pkg/pcic/protocol.go index b96e1cd..9a2cc29 100644 --- a/pkg/pcic/protocol.go +++ b/pkg/pcic/protocol.go @@ -3,6 +3,7 @@ package pcic import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -10,12 +11,17 @@ import ( "net" "strconv" "strings" + "sync" ) type ( PCICClient struct { - reader *bufio.Reader - writer *bufio.Writer + reader *bufio.Reader + writer *bufio.Writer + responseChans map[string]chan Response + ticketSet map[string]struct{} + + mu sync.Mutex } PCICClientOption func(c *PCICClient) error ) @@ -50,7 +56,6 @@ type MessageHandler interface { Result(Frame) Error(ErrorMessage) Notification(NotificationMessage) - CommandResponse(Response) } type NotificationMessage struct { @@ -103,6 +108,19 @@ func WithTCPClient(hostname string, port uint16) PCICClientOption { } } +func (p *PCICClient) generateTicket() string { + for { + ticket := fmt.Sprintf("%04d", rand.Intn(8999)+1000) + p.mu.Lock() + if _, exists := p.ticketSet[ticket]; !exists { + p.ticketSet[ticket] = struct{}{} + p.mu.Unlock() + return ticket + } + p.mu.Unlock() + } +} + func (p *PCICClient) ProcessIncomming(handler MessageHandler) error { reader := p.reader if reader == nil { @@ -151,11 +169,10 @@ func (p *PCICClient) ProcessIncomming(handler MessageHandler) error { } } if ticketNum > 100 { - r, err := responseParser(ticketStr, data) + err := p.responseParser(ticketStr, data) if err != nil { return fmt.Errorf("unable to parse the response: %w", err) } - handler.CommandResponse(r) return nil } else if bytes.Equal(resultTicket, firstTicket) { frame, err := asyncResultParser(data) @@ -169,18 +186,25 @@ func (p *PCICClient) ProcessIncomming(handler MessageHandler) error { return fmt.Errorf("unknown ticket received: %s", string(firstTicket)) } -func (p *PCICClient) Send(data []byte) (uint16, error) { - var ticket uint16 +func (p *PCICClient) Send(ctx context.Context, data []byte) ([]byte, error) { + var res []byte if p.writer == nil { - return ticket, errors.New("no bufio.Writer provided, please instantiate the object") + return res, errors.New("no bufio.Writer provided, please instantiate the object") } // Let's generate a random ticket number - ticket = uint16(rand.Intn(8999) + 1000) - if ticket < 100 || ticket > 9999 { - return ticket, fmt.Errorf( - "invalid ticket number: %d, needs to be in the range 100-9999", ticket, - ) - } + ticket := p.generateTicket() + + respChan := make(chan Response) + p.mu.Lock() + p.responseChans[ticket] = respChan + p.mu.Unlock() + + defer func() { + p.mu.Lock() + delete(p.responseChans, ticket) + delete(p.ticketSet, ticket) + p.mu.Unlock() + }() // Create a new buffer to aggregate the message var buf bytes.Buffer @@ -206,27 +230,41 @@ func (p *PCICClient) Send(data []byte) (uint16, error) { // Write the buffer to the underlying writer _, err := p.writer.Write(buf.Bytes()) if err != nil { - return ticket, fmt.Errorf("unable to write to the buffer: %w", err) + return res, fmt.Errorf("unable to write to the buffer: %w", err) } // This is necessary to flush the buffer to the underlying writer // Otherwise, the data will not be sent over the network err = p.writer.Flush() if err != nil { - return ticket, fmt.Errorf("unable to flush to the buffer: %w", err) + return res, fmt.Errorf("unable to flush to the buffer: %w", err) + } + + // Wait for the response or timeout + select { + case resp := <-respChan: + res = resp.Data + case <-ctx.Done(): + return res, fmt.Errorf("request with ticket: %s timed out or canceled", ticket) } - return ticket, nil + return res, nil } -func responseParser(ticket string, data []byte) (Response, error) { +func (p *PCICClient) responseParser(ticket string, data []byte) error { var err error res := Response{} if len(data) <= delimiterFieldLength { - return res, fmt.Errorf("the data is too short to be a valid frame: %d", len(data)) + return fmt.Errorf("the data is too short to be a valid frame: %d", len(data)) } res.Ticket = ticket res.Data = data[:len(data)-delimiterFieldLength] - return res, err + + p.mu.Lock() + if ch, ok := p.responseChans[ticket]; ok { + ch <- res + } + p.mu.Unlock() + return err } func errorParser(data []byte) (ErrorMessage, error) { diff --git a/pkg/pcic/protocol_test.go b/pkg/pcic/protocol_test.go index 69dbd7e..c3d3ef6 100644 --- a/pkg/pcic/protocol_test.go +++ b/pkg/pcic/protocol_test.go @@ -23,7 +23,6 @@ type PCICAsyncReceiver struct { frame pcic.Frame notificationMsg pcic.NotificationMessage errorMsg pcic.ErrorMessage - response pcic.Response } func (r *PCICAsyncReceiver) Result(frame pcic.Frame) { @@ -38,10 +37,6 @@ func (r *PCICAsyncReceiver) Notification(msg pcic.NotificationMessage) { r.notificationMsg = msg } -func (r *PCICAsyncReceiver) CommandResponse(res pcic.Response) { - r.response = res -} - var testHandler *PCICAsyncReceiver = &PCICAsyncReceiver{} func TestMinimalReceive(t *testing.T) { From ae7034e8779578fe6f1bc2c1f2502f230e529b74 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Sat, 6 Jul 2024 10:56:32 +0000 Subject: [PATCH 5/6] refactor: get rid of the ticketSet --- pkg/pcic/protocol.go | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/pkg/pcic/protocol.go b/pkg/pcic/protocol.go index 9a2cc29..d28e48c 100644 --- a/pkg/pcic/protocol.go +++ b/pkg/pcic/protocol.go @@ -19,9 +19,7 @@ type ( reader *bufio.Reader writer *bufio.Writer responseChans map[string]chan Response - ticketSet map[string]struct{} - - mu sync.Mutex + mu sync.Mutex } PCICClientOption func(c *PCICClient) error ) @@ -108,12 +106,16 @@ func WithTCPClient(hostname string, port uint16) PCICClientOption { } } -func (p *PCICClient) generateTicket() string { +// generateTicket generates a unique ticket and associates it with the provided reply channel. +// It uses a random number generator to generate a 4-digit ticket number. +// If the generated ticket already exists in the responseChans map, it continues generating a new ticket until a unique one is found. +// Once a unique ticket is found, it adds the ticket and reply channel to the responseChans map and returns the ticket. +func (p *PCICClient) generateTicket(replyChan chan Response) string { for { ticket := fmt.Sprintf("%04d", rand.Intn(8999)+1000) p.mu.Lock() - if _, exists := p.ticketSet[ticket]; !exists { - p.ticketSet[ticket] = struct{}{} + if _, exists := p.responseChans[ticket]; !exists { + p.responseChans[ticket] = replyChan p.mu.Unlock() return ticket } @@ -191,18 +193,13 @@ func (p *PCICClient) Send(ctx context.Context, data []byte) ([]byte, error) { if p.writer == nil { return res, errors.New("no bufio.Writer provided, please instantiate the object") } - // Let's generate a random ticket number - ticket := p.generateTicket() - respChan := make(chan Response) - p.mu.Lock() - p.responseChans[ticket] = respChan - p.mu.Unlock() + // Let's generate a random ticket number + ticket := p.generateTicket(respChan) defer func() { p.mu.Lock() delete(p.responseChans, ticket) - delete(p.ticketSet, ticket) p.mu.Unlock() }() @@ -210,7 +207,7 @@ func (p *PCICClient) Send(ctx context.Context, data []byte) ([]byte, error) { var buf bytes.Buffer var delimter = []byte("\r\n") // Convert ticket to a 4-digit string and then to bytes - ticketBytes := []byte(fmt.Sprintf("%04d", ticket)) + ticketBytes := []byte(ticket) buf.Write(ticketBytes) // A Command message is composed like this // CRLFCRLF From a17a4f58ba0bdef43f9794a2e7b5340b025451da Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Tue, 9 Jul 2024 14:07:29 +0000 Subject: [PATCH 6/6] fix: creation of the channel map --- pkg/pcic/protocol.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/pcic/protocol.go b/pkg/pcic/protocol.go index d28e48c..86b0825 100644 --- a/pkg/pcic/protocol.go +++ b/pkg/pcic/protocol.go @@ -73,7 +73,10 @@ type Response struct { func NewPCICClient(options ...PCICClientOption) (*PCICClient, error) { var err error - pcic := &PCICClient{} + pcic := &PCICClient{ + responseChans: make(map[string]chan Response), + } + // Apply options for _, opt := range options { if err = opt(pcic); err != nil {