diff --git a/cmd/list.go b/cmd/list.go index 4753fcd..db2339e 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog/log" "github.com/davegallant/vpngate/pkg/vpn" + "github.com/spf13/cobra" ) @@ -20,6 +21,7 @@ var listCmd = &cobra.Command{ Short: "List all available vpn servers", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { + vpnServers, err := vpn.GetList() if err != nil { log.Fatal().Msgf(err.Error()) diff --git a/pkg/util/retry.go b/pkg/util/retry.go new file mode 100644 index 0000000..0bb9a77 --- /dev/null +++ b/pkg/util/retry.go @@ -0,0 +1,18 @@ +package util + +import ( + "time" + "github.com/rs/zerolog/log" +) + +func Retry(attempts int, delay time.Duration,fn func() error) error { + var err error + for i := 0; i < attempts; i++ { + if err = fn(); err == nil { + return nil + } + log.Error().Msgf("Retrying after %d seconds. An error occured: %s", delay, err) + time.Sleep(delay) + } + return err +} diff --git a/pkg/vpn/list.go b/pkg/vpn/list.go index bca2296..0fe3bda 100644 --- a/pkg/vpn/list.go +++ b/pkg/vpn/list.go @@ -8,6 +8,7 @@ import ( "github.com/jszwec/csvutil" "github.com/rs/zerolog/log" + "github.com/davegallant/vpngate/pkg/util" "github.com/juju/errors" ) @@ -44,6 +45,7 @@ func parseVpnList(r io.Reader) (*[]Server, error) { // Trim known invalid rows serverList = bytes.TrimPrefix(serverList, []byte("*vpn_servers\r\n")) serverList = bytes.TrimSuffix(serverList, []byte("*\r\n")) + serverList = bytes.ReplaceAll(serverList, []byte(`"`), []byte{}) if err := csvutil.Unmarshal(serverList, &servers); err != nil { return nil, errors.Annotatef(err, "Unable to parse CSV") @@ -73,27 +75,36 @@ func GetList() (*[]Server, error) { log.Info().Msg("Fetching the latest server list") - r, err := http.Get(vpnList) - if err != nil { - return nil, errors.Annotate(err, "Unable to retrieve vpn list") - } + var r *http.Response - defer r.Body.Close() + err := util.Retry(5, 1, func() error { + var err error + r, err = http.Get(vpnList) + if err != nil { + return err + } + defer r.Body.Close() - if r.StatusCode != 200 { - return nil, errors.Annotatef(err, "Unexpected status code when retrieving vpn list: %d", r.StatusCode) - } + if r.StatusCode != 200 { + return errors.Annotatef(err, "Unexpected status code when retrieving vpn list: %d", r.StatusCode) + } - servers, err = parseVpnList(r.Body) + servers, err = parseVpnList(r.Body) + + if err != nil { + return err + } + + err = writeVpnListToCache(*servers) + + if err != nil { + log.Warn().Msgf("Unable to write servers to cache: %s", err) + } + return nil + }) if err != nil { - return nil, errors.Annotate(err, "unable to parse vpn list") - } - - err = writeVpnListToCache(*servers) - - if err != nil { - log.Warn().Msgf("Unable to write servers to cache: %s", err) + return nil, err } return servers, nil