diff --git a/cmd/cli/add.go b/cmd/cli/add.go index cba5fa4..59f8834 100644 --- a/cmd/cli/add.go +++ b/cmd/cli/add.go @@ -9,23 +9,35 @@ import ( ) // Add prompts for the required information and creates a new peer -func Add(hostname string, privKey, pubKey bool, owner, description string, confirm bool) { +func Add(hostname string, privKey, pubKey bool, owner, description string, confirm bool) error { config, err := LoadConfigFile() - check(err, "failed to load configuration file") + if err != nil { + return fmt.Errorf("%w - failed to load configuration file", err) + } server := GetServer(config) var private, public string if privKey { - private = MustPromptString("private key", true) + if private, err = PromptString("private key", true); err != nil { + return err + } } if pubKey { - public = MustPromptString("public key", true) + if public, err = PromptString("public key", true); err != nil { + return err + } } if owner == "" { - owner = MustPromptString("owner", true) + owner, err = PromptString("owner", true) + if err != nil { + return fmt.Errorf("%w - invalid input for owner", err) + } } if description == "" { - description = MustPromptString("Description", true) + description, err = PromptString("Description", true) + if err != nil { + return fmt.Errorf("%w - invalid input for Description", err) + } } // publicKey := MustPromptString("PublicKey (optional)", false) @@ -37,22 +49,32 @@ func Add(hostname string, privKey, pubKey bool, owner, description string, confi fmt.Fprintln(os.Stderr) peer, err := lib.NewPeer(server, private, public, owner, hostname, description) - check(err, "failed to get new peer") + if err != nil { + return fmt.Errorf("%w - failed to get new peer", err) + } // TODO Some kind of recovery here would be nice, to avoid // leaving things in a potential broken state - config.MustAddPeer(peer) + if err = config.AddPeer(peer); err != nil { + return fmt.Errorf("%w - failed to add new peer", err) + } peerType := viper.GetString("output") peerConfigBytes, err := lib.AsciiPeerConfig(peer, peerType, *server) - check(err, "failed to get peer configuration") + if err != nil { + return fmt.Errorf("%w - failed to get peer configuration", err) + } os.Stdout.Write(peerConfigBytes.Bytes()) - config.MustSave() + if err = config.Save(); err != nil { + return fmt.Errorf("%w - failed to save config file", err) + } server = GetServer(config) - err = server.ConfigureDevice() - check(err, "failed to configure device") + if err = server.ConfigureDevice(); err != nil { + return fmt.Errorf("%w - failed to configure device", err) + } + return nil } diff --git a/cmd/cli/config.go b/cmd/cli/config.go index 6e2d4ad..55a294c 100644 --- a/cmd/cli/config.go +++ b/cmd/cli/config.go @@ -106,12 +106,6 @@ func LoadConfigFile() (*DsnetConfig, error) { return &conf, nil } -func MustLoadConfigFile() *DsnetConfig { - config, err := LoadConfigFile() - check(err, "failed to load configuration file") - return config -} - // Save writes the configuration to disk func (conf *DsnetConfig) Save() error { configFile := viper.GetString("config_file") @@ -124,12 +118,6 @@ func (conf *DsnetConfig) Save() error { return nil } -// MustSave is like Save except it exits on error -func (conf *DsnetConfig) MustSave() { - err := conf.Save() - check(err, "failed to save config file") -} - // AddPeer adds a provided peer to the Peers list in the conf func (conf *DsnetConfig) AddPeer(peer lib.Peer) error { // TODO validate all PeerConfig (keys etc) @@ -169,12 +157,6 @@ func (conf *DsnetConfig) AddPeer(peer lib.Peer) error { return nil } -// MustAddPeer is like AddPeer, except it exist on error -func (conf *DsnetConfig) MustAddPeer(peer lib.Peer) { - err := conf.AddPeer(peer) - check(err) -} - // RemovePeer removes a peer from the peer list based on hostname func (conf *DsnetConfig) RemovePeer(hostname string) error { peerIndex := -1 @@ -195,12 +177,6 @@ func (conf *DsnetConfig) RemovePeer(hostname string) error { return nil } -// MustRemovePeer is like RemovePeer, except it exits on error -func (conf *DsnetConfig) MustRemovePeer(hostname string) { - err := conf.RemovePeer(hostname) - check(err) -} - func (conf DsnetConfig) GetWgPeerConfigs() []wgtypes.PeerConfig { wgPeers := make([]wgtypes.PeerConfig, 0, len(conf.Peers)) diff --git a/cmd/cli/init.go b/cmd/cli/init.go index 5cbde3f..cfd457b 100644 --- a/cmd/cli/init.go +++ b/cmd/cli/init.go @@ -15,7 +15,7 @@ import ( "github.com/spf13/viper" ) -func Init() { +func Init() error { listenPort := viper.GetInt("listen_port") configFile := viper.GetString("config_file") interfaceName := viper.GetString("interface_name") @@ -23,14 +23,23 @@ func Init() { _, err := os.Stat(configFile) if !os.IsNotExist(err) { - ExitFail("Refusing to overwrite existing %s", configFile) + return fmt.Errorf("%w - Refusing to overwrite existing %s", err, configFile) } privateKey, err := lib.GenerateJSONPrivateKey() - check(err, "failed to generate private key") + if err != nil { + return fmt.Errorf("%w - failed to generate private key", err) + } externalIPV4, err := getExternalIP() - check(err) + if err != nil { + return err + } + + externalIPV6, err := getExternalIP6() + if err != nil { + return err + } conf := &DsnetConfig{ PrivateKey: privateKey, @@ -40,7 +49,7 @@ func Init() { Peers: []PeerConfig{}, Domain: "dsnet", ExternalIP: externalIPV4, - ExternalIP6: getExternalIP6(), + ExternalIP6: externalIPV6, InterfaceName: interfaceName, Networks: []lib.JSONIPNet{}, PersistentKeepalive: 25, @@ -49,21 +58,28 @@ func Init() { server := GetServer(conf) ipv4, err := server.AllocateIP() - check(err, "failed to allocate ipv4 address") + if err != nil { + return fmt.Errorf("%w - failed to allocate ipv4 address", err) + } ipv6, err := server.AllocateIP6() - check(err, "failed to allocate ipv6 address") + if err != nil { + return fmt.Errorf("%w - failed to allocate ipv6 address", err) + } conf.IP = ipv4 conf.IP6 = ipv6 if len(conf.ExternalIP) == 0 && len(conf.ExternalIP6) == 0 { - ExitFail("Could not determine any external IP, v4 or v6") + return fmt.Errorf("Could not determine any external IP, v4 or v6") } - conf.MustSave() + if err := conf.Save(); err != nil { + return fmt.Errorf("%w - failed to save config file", err) + } fmt.Printf("Config written to %s. Please check/edit.\n", configFile) + return nil } // get a random IPv4 /22 subnet on 10.0.0.0 (1023 hosts) (or /24?) @@ -119,12 +135,16 @@ func getExternalIP() (net.IP, error) { Timeout: 5 * time.Second, } resp, err := client.Get("https://ipv4.icanhazip.com/") - check(err) + if err != nil { + return nil, err + } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { body, err := ioutil.ReadAll(resp.Body) - check(err) + if err != nil { + return nil, err + } IP = net.ParseIP(strings.TrimSpace(string(body))) return IP.To4(), nil } @@ -132,7 +152,7 @@ func getExternalIP() (net.IP, error) { return nil, errors.New("failed to determine external ip") } -func getExternalIP6() net.IP { +func getExternalIP6() (net.IP, error) { var IP net.IP conn, err := net.Dial("udp", "2001:4860:4860::8888:53") if err == nil { @@ -143,7 +163,7 @@ func getExternalIP6() net.IP { // check is not a ULA if IP[0] != 0xfd && IP[0] != 0xfc { - return IP + return IP, nil } } @@ -156,11 +176,13 @@ func getExternalIP6() net.IP { if resp.StatusCode == http.StatusOK { body, err := ioutil.ReadAll(resp.Body) - check(err) + if err != nil { + return nil, err + } IP = net.ParseIP(strings.TrimSpace(string(body))) - return IP + return IP, nil } } - return net.IP{} + return net.IP{}, nil } diff --git a/cmd/cli/regenerate.go b/cmd/cli/regenerate.go index b1c6676..8da8e17 100644 --- a/cmd/cli/regenerate.go +++ b/cmd/cli/regenerate.go @@ -8,8 +8,11 @@ import ( "github.com/spf13/viper" ) -func Regenerate(hostname string, confirm bool) { - config := MustLoadConfigFile() +func Regenerate(hostname string, confirm bool) error { + config, err := LoadConfigFile() + if err != nil { + return fmt.Errorf("%w - failure to load config file", err) + } server := GetServer(config) found := false @@ -21,36 +24,49 @@ func Regenerate(hostname string, confirm bool) { for _, peer := range server.Peers { if peer.Hostname == hostname { privateKey, err := lib.GenerateJSONPrivateKey() - check(err, "failed to generate private key") + if err != nil { + return fmt.Errorf("%w - failed to generate private key", err) + } preshareKey, err := lib.GenerateJSONKey() - check(err, "failed to generate preshared key") + if err != nil { + return fmt.Errorf("%w - failed to generate preshared key", err) + } peer.PrivateKey = privateKey peer.PublicKey = privateKey.PublicKey() peer.PresharedKey = preshareKey err = config.RemovePeer(hostname) - check(err, "failed to regenerate peer") + if err != nil { + return fmt.Errorf("%w - failed to regenerate peer", err) + } peerType := viper.GetString("output") peerConfigBytes, err := lib.AsciiPeerConfig(peer, peerType, *server) - check(err, "failed to get peer configuration") + if err != nil { + return fmt.Errorf("%w - failed to get peer configuration", err) + } os.Stdout.Write(peerConfigBytes.Bytes()) found = true - config.MustAddPeer(peer) + if err = config.AddPeer(peer); err != nil { + return fmt.Errorf("%w - failure to add peer", err) + } break } } if !found { - ExitFail(fmt.Sprintf("unknown hostname: %s", hostname)) + return fmt.Errorf("unknown hostname: %s", hostname) } // Get a new server configuration so we can update the wg interface with the new peer details server = GetServer(config) - config.MustSave() + if err = config.Save(); err != nil { + return fmt.Errorf("%w - failure saving config", err) + } server.ConfigureDevice() + return nil } diff --git a/cmd/cli/remove.go b/cmd/cli/remove.go index 8eba3e9..d986b9f 100644 --- a/cmd/cli/remove.go +++ b/cmd/cli/remove.go @@ -2,19 +2,27 @@ package cli import "fmt" -func Remove(hostname string, confirm bool) { - conf := MustLoadConfigFile() +func Remove(hostname string, confirm bool) error { + conf, err := LoadConfigFile() + if err != nil { + return fmt.Errorf("%w - failed to load config", err) + } - err := conf.RemovePeer(hostname) - check(err, "failed to update config") + if err = conf.RemovePeer(hostname); err != nil { + return fmt.Errorf("%w - failed to update config", err) + } if !confirm { ConfirmOrAbort("Do you really want to remove %s?", hostname) } - conf.MustSave() + if err = conf.Save(); err != nil { + return fmt.Errorf("%w - failure to save config", err) + } server := GetServer(conf) - err = server.ConfigureDevice() - check(err, fmt.Sprintf("failed to sync server config to wg interface: %s", server.InterfaceName)) + if err = server.ConfigureDevice(); err != nil { + return fmt.Errorf("%w - failed to sync server config to wg interface: %s", err, server.InterfaceName) + } + return nil } diff --git a/cmd/cli/report.go b/cmd/cli/report.go index a687ff5..24e3743 100644 --- a/cmd/cli/report.go +++ b/cmd/cli/report.go @@ -69,24 +69,33 @@ type PeerReport struct { TransmitBytesSI string } -func GenerateReport() { - conf := MustLoadConfigFile() +func GenerateReport() error { + conf, err := LoadConfigFile() + if err != nil { + return fmt.Errorf("%w - failure to load config", err) + } wg, err := wgctrl.New() - check(err) + if err != nil { + return fmt.Errorf("%w - failure to create new client", err) + } defer wg.Close() dev, err := wg.Device(conf.InterfaceName) if err != nil { - ExitFail("Could not retrieve device '%s' (%v)", conf.InterfaceName, err) + return fmt.Errorf("%w - Could not retrieve device '%s'", err, conf.InterfaceName) } - report := GetReport(dev, conf) + report, err := GetReport(dev, conf) + if err != nil { + return err + } report.Print() + return nil } -func GetReport(dev *wgtypes.Device, conf *DsnetConfig) DsnetReport { +func GetReport(dev *wgtypes.Device, conf *DsnetConfig) (DsnetReport, error) { peerTimeout := viper.GetDuration("peer_timeout") peerExpiry := viper.GetDuration("peer_expiry") wgPeerIndex := make(map[wgtypes.Key]wgtypes.Peer) @@ -94,7 +103,9 @@ func GetReport(dev *wgtypes.Device, conf *DsnetConfig) DsnetReport { peersOnline := 0 linkDev, err := netlink.LinkByName(conf.InterfaceName) - check(err) + if err != nil { + return DsnetReport{}, fmt.Errorf("%w - error getting link", err) + } stats := linkDev.Attrs().Statistics @@ -165,7 +176,7 @@ func GetReport(dev *wgtypes.Device, conf *DsnetConfig) DsnetReport { ReceiveBytesSI: BytesToSI(stats.RxBytes), TransmitBytesSI: BytesToSI(stats.TxBytes), Timestamp: time.Now(), - } + }, nil } func (report *DsnetReport) Print() { diff --git a/cmd/cli/sync.go b/cmd/cli/sync.go index 50c5bc9..ff637c2 100644 --- a/cmd/cli/sync.go +++ b/cmd/cli/sync.go @@ -1,10 +1,17 @@ package cli -func Sync() { +import "fmt" + +func Sync() error { // TODO check device settings first conf, err := LoadConfigFile() - check(err, "failed to load configuration file") + if err != nil { + return fmt.Errorf("%w - failed to load configuration file", err) + } server := GetServer(conf) err = server.ConfigureDevice() - check(err, "failed to sync device configuration") + if err != nil { + return fmt.Errorf("%w - failed to sync device configuration", err) + } + return nil } diff --git a/cmd/cli/types.go b/cmd/cli/types.go index cb84166..dee25f4 100644 --- a/cmd/cli/types.go +++ b/cmd/cli/types.go @@ -27,22 +27,25 @@ func (k *JSONKey) UnmarshalJSON(b []byte) error { return err } -func GenerateJSONPrivateKey() JSONKey { +func GenerateJSONPrivateKey() (JSONKey, error) { privateKey, err := wgtypes.GeneratePrivateKey() - - check(err) + if err != nil { + return JSONKey{}, err + } return JSONKey{ Key: privateKey, - } + }, nil } -func GenerateJSONKey() JSONKey { +func GenerateJSONKey() (JSONKey, error) { privateKey, err := wgtypes.GenerateKey() - check(err) + if err != nil { + return JSONKey{}, err + } return JSONKey{ Key: privateKey, - } + }, err } diff --git a/cmd/cli/util.go b/cmd/cli/util.go index 461f3bb..4dea73c 100644 --- a/cmd/cli/util.go +++ b/cmd/cli/util.go @@ -1,5 +1,7 @@ package cli +// FIXME every function in this file has public scope, but only private references + import ( "bufio" "fmt" @@ -9,15 +11,6 @@ import ( "github.com/naggie/dsnet/lib" ) -func check(e error, optMsg ...string) { - if e != nil { - if len(optMsg) > 0 { - ExitFail("%s - %s", e, strings.Join(optMsg, " ")) - } - ExitFail("%s", e) - } -} - func jsonPeerToDsnetPeer(peers []PeerConfig) []lib.Peer { libPeers := make([]lib.Peer, 0, len(peers)) for _, p := range peers { @@ -37,12 +30,7 @@ func jsonPeerToDsnetPeer(peers []PeerConfig) []lib.Peer { return libPeers } -func ExitFail(format string, a ...interface{}) { - fmt.Fprintf(os.Stderr, "\033[31m"+format+"\033[0m\n", a...) - os.Exit(1) -} - -func MustPromptString(prompt string, required bool) string { +func PromptString(prompt string, required bool) (string, error) { reader := bufio.NewReader(os.Stdin) var text string var err error @@ -50,12 +38,15 @@ func MustPromptString(prompt string, required bool) string { for text == "" { fmt.Fprintf(os.Stderr, "%s: ", prompt) text, err = reader.ReadString('\n') - check(err) + if err != nil { + return "", fmt.Errorf("%w - error getting input", err) + } text = strings.TrimSpace(text) } - return text + return text, nil } +// FIXME is it critical for this to panic, or can we cascade the errors? func ConfirmOrAbort(format string, a ...interface{}) { fmt.Fprintf(os.Stderr, format+" [y/n] ", a...) @@ -69,7 +60,8 @@ func ConfirmOrAbort(format string, a ...interface{}) { if input == "y\n" { return } else { - ExitFail("Aborted.") + fmt.Fprintf(os.Stderr, "\033[31mAborted.\033[0m\n") + os.Exit(1) } } diff --git a/cmd/root.go b/cmd/root.go index 947110b..bd5b422 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -29,8 +29,8 @@ var ( "Create %s containing default configuration + new keys without loading. Edit to taste.", viper.GetString("config_file"), ), - Run: func(cmd *cobra.Command, args []string) { - cli.Init() + RunE: func(cmd *cobra.Command, args []string) error { + return cli.Init() }, } @@ -38,7 +38,10 @@ var ( Use: "up", Short: "Create the interface, run pre/post up, sync", RunE: func(cmd *cobra.Command, args []string) error { - config := cli.MustLoadConfigFile() + config, err := cli.LoadConfigFile() + if err != nil { + return fmt.Errorf("%w - failure to load config file", err) + } server := cli.GetServer(config) if e := server.Up(); e != nil { return e @@ -54,7 +57,10 @@ var ( Use: "down", Short: "Destroy the interface, run pre/post down", RunE: func(cmd *cobra.Command, args []string) error { - config := cli.MustLoadConfigFile() + config, err := cli.LoadConfigFile() + if err != nil { + return fmt.Errorf("%w - failure to load config file", err) + } server := cli.GetServer(config) if e := server.DeleteLink(); e != nil { return e @@ -76,13 +82,16 @@ var ( } return nil }, - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { privKey, err := cmd.PersistentFlags().GetBool("private-key") + if err != nil { + return err + } pubKey, err := cmd.PersistentFlags().GetBool("public-key") if err != nil { - cli.ExitFail("%w - error processing key flag", err) + return err } - cli.Add(args[0], privKey, pubKey, owner, description, confirm) + return cli.Add(args[0], privKey, pubKey, owner, description, confirm) }, } @@ -95,24 +104,24 @@ var ( } return nil }, - Run: func(cmd *cobra.Command, args []string) { - cli.Regenerate(args[0], confirm) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.Regenerate(args[0], confirm) }, } syncCmd = &cobra.Command{ Use: "sync", Short: fmt.Sprintf("Update wireguard configuration from %s after validating", viper.GetString("config_file")), - Run: func(cmd *cobra.Command, args []string) { - cli.Sync() + RunE: func(cmd *cobra.Command, args []string) error { + return cli.Sync() }, } reportCmd = &cobra.Command{ Use: "report", Short: "Generate a JSON status report to stdout", - Run: func(cmd *cobra.Command, args []string) { - cli.GenerateReport() + RunE: func(cmd *cobra.Command, args []string) error { + return cli.GenerateReport() }, } @@ -127,8 +136,8 @@ var ( return nil }, - Run: func(cmd *cobra.Command, args []string) { - cli.Remove(args[0], confirm) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.Remove(args[0], confirm) }, } @@ -158,7 +167,8 @@ func init() { viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) if err := viper.BindPFlag("output", rootCmd.PersistentFlags().Lookup("output")); err != nil { - cli.ExitFail(err.Error()) + fmt.Fprintf(os.Stderr, "\033[31m%s\033[0m\n", err.Error()) + os.Exit(1) } viper.SetDefault("config_file", "/etc/dsnetconfig.json") @@ -186,7 +196,9 @@ func init() { func main() { if err := rootCmd.Execute(); err != nil { - cli.ExitFail(err.Error()) + // Because of side effects in viper, this gets printed twice + fmt.Fprintf(os.Stderr, "\033[31m%s\033[0m\n", err.Error()) + os.Exit(1) } os.Exit(0) }