Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions ip/aws/ip_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"cloudip/util"
"errors"
"fmt"
"log"
"net/http"
"os"
"strings"
Expand Down Expand Up @@ -128,29 +127,25 @@ func (ipDataManagerAws *IpDataManagerAws) EnsureDataFile() error {
return nil
}

func (ipDataManagerAws *IpDataManagerAws) LoadIpData() *IpRangeDataAws {
func (ipDataManagerAws *IpDataManagerAws) LoadIpData() (*IpRangeDataAws, error) {
if !ipDataManagerAws.IpRange.IsEmpty() {
return &ipDataManagerAws.IpRange
return &ipDataManagerAws.IpRange, nil
}

awsIpRangeData := IpRangeDataAws{}
ipDataFile, err := os.Open(ipDataManagerAws.DataFilePath)
if err != nil {
err = util.ErrorWithInfo(err, "error opening data file")
util.PrintErrorTrace(err)
log.Fatal(err)
return nil, util.ErrorWithInfo(err, "error opening data file")
}
defer ipDataFile.Close()

err = util.ReadJSON(ipDataFile, &awsIpRangeData)
if err != nil {
err = util.ErrorWithInfo(err, "error reading data file")
util.PrintErrorTrace(err)
log.Fatal(err)
return nil, util.ErrorWithInfo(err, "error reading data file")
}

ipDataManagerAws.IpRange = awsIpRangeData
return &ipDataManagerAws.IpRange
return &ipDataManagerAws.IpRange, nil
}

var ipDataManagerAws = &IpDataManagerAws{
Expand Down
17 changes: 17 additions & 0 deletions ip/aws/ip_data_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package aws

import (
"path/filepath"
"testing"
)

func TestAWSLoadIpDataReturnsErrorForMissingFile(t *testing.T) {
dir := t.TempDir()
manager := &IpDataManagerAws{
DataFilePath: filepath.Join(dir, "missing.json"),
}

if _, err := manager.LoadIpData(); err == nil {
t.Fatal("LoadIpData() error = nil, want error")
}
}
5 changes: 4 additions & 1 deletion ip/aws/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ type AWSProvider struct {
func NewAWSProvider() *AWSProvider {
return &AWSProvider{
BaseProvider: provider.NewBaseProvider("AWS", ipDataManagerAws, func(bp *provider.BaseProvider) error {
awsIpRangeData := *ipDataManagerAws.LoadIpData()
awsIpRangeData, err := ipDataManagerAws.LoadIpData()
if err != nil {
return err
}

for _, prefix := range awsIpRangeData.Prefixes {
bp.AddIPv4Range(prefix.IpPrefix)
Expand Down
15 changes: 5 additions & 10 deletions ip/azure/ip_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"cloudip/util"
"errors"
"fmt"
"log"
"os"
"sync"
"time"
Expand Down Expand Up @@ -142,29 +141,25 @@ func (ipDataManagerAzure *IpDataManagerAzure) EnsureDataFile() error {
return nil
}

func (ipDataManagerAzure *IpDataManagerAzure) LoadIpData() *IpRangeDataAzure {
func (ipDataManagerAzure *IpDataManagerAzure) LoadIpData() (*IpRangeDataAzure, error) {
if !ipDataManagerAzure.IpRange.IsEmpty() {
return &ipDataManagerAzure.IpRange
return &ipDataManagerAzure.IpRange, nil
}

azureIpRangeData := IpRangeDataAzure{}
ipDataFile, err := os.Open(ipDataManagerAzure.DataFilePath)
if err != nil {
err = util.ErrorWithInfo(err, "error loading data file")
util.PrintErrorTrace(err)
log.Fatal(err)
return nil, util.ErrorWithInfo(err, "error loading data file")
}
defer ipDataFile.Close()

err = util.ReadJSON(ipDataFile, &azureIpRangeData)
if err != nil {
err = util.ErrorWithInfo(err, "error reading data file")
util.PrintErrorTrace(err)
log.Fatal(err)
return nil, util.ErrorWithInfo(err, "error reading data file")
}

ipDataManagerAzure.IpRange = azureIpRangeData
return &ipDataManagerAzure.IpRange
return &ipDataManagerAzure.IpRange, nil
}

var ipDataManagerAzure = &IpDataManagerAzure{
Expand Down
12 changes: 12 additions & 0 deletions ip/azure/ip_data_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package azure

import (
"path/filepath"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -42,3 +43,14 @@ func TestAzureEnsureDataURIConcurrentAccess(t *testing.T) {
}
}
}

func TestAzureLoadIpDataReturnsErrorForMissingFile(t *testing.T) {
dir := t.TempDir()
manager := &IpDataManagerAzure{
DataFilePath: filepath.Join(dir, "missing.json"),
}

if _, err := manager.LoadIpData(); err == nil {
t.Fatal("LoadIpData() error = nil, want error")
}
}
5 changes: 4 additions & 1 deletion ip/azure/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ type AzureProvider struct {
func NewAzureProvider() *AzureProvider {
return &AzureProvider{
BaseProvider: provider.NewBaseProvider("Azure", ipDataManagerAzure, func(bp *provider.BaseProvider) error {
azureIpRangeData := *ipDataManagerAzure.LoadIpData()
azureIpRangeData, err := ipDataManagerAzure.LoadIpData()
if err != nil {
return err
}

for _, dataObject := range azureIpRangeData.Values {
for _, prefix := range dataObject.Properties.AddressPrefixes {
Expand Down
32 changes: 18 additions & 14 deletions ip/cloudflare/ip_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"cloudip/util"
"errors"
"io"
"log"
"os"
"strings"
)
Expand Down Expand Up @@ -125,34 +124,39 @@ func (m *IpDataManagerCloudflare) EnsureDataFile() error {
return nil
}

func (m *IpDataManagerCloudflare) LoadIpData() *IpRangeDataCloudflare {
func (m *IpDataManagerCloudflare) LoadIpData() (*IpRangeDataCloudflare, error) {
if !m.IpRange.IsEmpty() {
return &m.IpRange
return &m.IpRange, nil
}

v4CIDRs, err := readCIDRLines(m.DataFilePathV4)
if err != nil {
return nil, err
}
v6CIDRs, err := readCIDRLines(m.DataFilePathV6)
if err != nil {
return nil, err
}

data := IpRangeDataCloudflare{
V4CIDRs: readCIDRLines(m.DataFilePathV4),
V6CIDRs: readCIDRLines(m.DataFilePathV6),
V4CIDRs: v4CIDRs,
V6CIDRs: v6CIDRs,
}

m.IpRange = data
return &m.IpRange
return &m.IpRange, nil
}

func readCIDRLines(path string) []string {
func readCIDRLines(path string) ([]string, error) {
file, err := os.Open(path)
if err != nil {
err = util.ErrorWithInfo(err, "error opening data file")
util.PrintErrorTrace(err)
log.Fatal(err)
return nil, util.ErrorWithInfo(err, "error opening data file")
}
defer file.Close()

content, err := io.ReadAll(file)
if err != nil {
err = util.ErrorWithInfo(err, "error reading data file")
util.PrintErrorTrace(err)
log.Fatal(err)
return nil, util.ErrorWithInfo(err, "error reading data file")
}

lines := strings.Split(string(content), "\n")
Expand All @@ -164,7 +168,7 @@ func readCIDRLines(path string) []string {
}
cidrs = append(cidrs, trimmed)
}
return cidrs
return cidrs, nil
}

func writeCIDRLines(path string, cidrs []string) error {
Expand Down
17 changes: 16 additions & 1 deletion ip/cloudflare/ip_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ func TestCloudflareDownloadDataWritesCIDRsAndSignature(t *testing.T) {
t.Fatalf("metadata signature = %q, want %q", metadataManager.Metadata.Signature, "cf-etag")
}

data := manager.LoadIpData()
data, err := manager.LoadIpData()
if err != nil {
t.Fatalf("LoadIpData() error = %v", err)
}
if len(data.V4CIDRs) != 1 || data.V4CIDRs[0] != "173.245.48.0/20" {
t.Fatalf("V4CIDRs = %#v, want Cloudflare IPv4 CIDR", data.V4CIDRs)
}
Expand Down Expand Up @@ -132,3 +135,15 @@ func TestCloudflareEnsureDataFileReusesFetchedDataForUpdate(t *testing.T) {
t.Fatalf("metadata signature = %q, want %q", metadataManager.Metadata.Signature, "new-etag")
}
}

func TestCloudflareLoadIpDataReturnsErrorForMissingFile(t *testing.T) {
dir := t.TempDir()
manager := &IpDataManagerCloudflare{
DataFilePathV4: filepath.Join(dir, "missing-v4.txt"),
DataFilePathV6: filepath.Join(dir, "missing-v6.txt"),
}

if _, err := manager.LoadIpData(); err == nil {
t.Fatal("LoadIpData() error = nil, want error")
}
}
5 changes: 4 additions & 1 deletion ip/cloudflare/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ type CloudflareProvider struct {
func NewCloudflareProvider() *CloudflareProvider {
return &CloudflareProvider{
BaseProvider: provider.NewBaseProvider("Cloudflare", ipDataManagerCloudflare, func(bp *provider.BaseProvider) error {
data := *ipDataManagerCloudflare.LoadIpData()
data, err := ipDataManagerCloudflare.LoadIpData()
if err != nil {
return err
}

for _, cidr := range data.V4CIDRs {
bp.AddIPv4Range(cidr)
Expand Down
11 changes: 4 additions & 7 deletions ip/gcp/ip_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"cloudip/util"
"errors"
"fmt"
"log"
"os"
)

Expand Down Expand Up @@ -121,20 +120,18 @@ func (ipDataManagerGcp *IpDataManagerGcp) EnsureDataFile() error {
return nil
}

func (ipDataManagerGcp *IpDataManagerGcp) LoadIpData() *IpRangeDataGcp {
func (ipDataManagerGcp *IpDataManagerGcp) LoadIpData() (*IpRangeDataGcp, error) {
if !ipDataManagerGcp.IpRange.IsEmpty() {
return &ipDataManagerGcp.IpRange
return &ipDataManagerGcp.IpRange, nil
}

gcpIpRangeData, err := ipDataManagerGcp.readDataFile()
if err != nil {
util.PrintErrorTrace(util.ErrorWithInfo(err, "error opening data file"))
util.PrintErrorTrace(err)
log.Fatal(err)
return nil, err
}

ipDataManagerGcp.IpRange = *gcpIpRangeData
return &ipDataManagerGcp.IpRange
return &ipDataManagerGcp.IpRange, nil
}

func (ipDataManagerGcp *IpDataManagerGcp) readDataFile() (*IpRangeDataGcp, error) {
Expand Down
11 changes: 11 additions & 0 deletions ip/gcp/ip_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ func TestGCPEnsureDataFileReusesFetchedDataForUpdate(t *testing.T) {
t.Fatalf("metadata signature = %q, want %q", metadataManager.Metadata.Signature, "new-sync-token")
}
}

func TestGCPLoadIpDataReturnsErrorForMissingFile(t *testing.T) {
dir := t.TempDir()
manager := &IpDataManagerGcp{
DataFilePath: filepath.Join(dir, "missing.json"),
}

if _, err := manager.LoadIpData(); err == nil {
t.Fatal("LoadIpData() error = nil, want error")
}
}
5 changes: 4 additions & 1 deletion ip/gcp/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ type GCPProvider struct {
func NewGCPProvider() *GCPProvider {
return &GCPProvider{
BaseProvider: provider.NewBaseProvider("GCP", ipDataManagerGcp, func(bp *provider.BaseProvider) error {
gcpIpRangeData := *ipDataManagerGcp.LoadIpData()
gcpIpRangeData, err := ipDataManagerGcp.LoadIpData()
if err != nil {
return err
}

for _, prefix := range gcpIpRangeData.Prefixes {
if prefix.Ipv4Prefix != "" {
Expand Down