Browse Source

got the recursive func working and added some test cases

AETH-erial 10 months ago
parent
commit
d7bc9d87b4
3 changed files with 173 additions and 110 deletions
  1. 61 89
      pkg/local.go
  2. 112 20
      pkg/local_test.go
  3. 0 1
      test/slash16_ips.json

+ 61 - 89
pkg/local.go

@@ -5,13 +5,10 @@ import (
 	"log"
 	"net"
 	"net/netip"
+	"strconv"
 	"strings"
 )
 
-type AllAddress struct {
-	Addr []net.IP `json:"addresses"`
-}
-
 type NetworkInterfaceNotFound struct{ Passed string }
 
 // Implementing error interface
@@ -19,40 +16,69 @@ func (n *NetworkInterfaceNotFound) Error() string {
 	return fmt.Sprintf("Interface: '%s' not found.", n.Passed)
 }
 
-func getNextAddr(addr net.IP) (net.IP, error) {
-	next, err := netip.ParseAddr(addr.String())
+type IpSubnetMapper struct {
+	Ipv4s       []net.IP `json:"addresses"`
+	NetworkAddr net.IP
+	Current     net.IP
+	Mask        int
+}
+
+/*
+Get the next IPv4 address of the address specified in the 'addr' argument,
+
+	:param addr: the address to get the next address of
+*/
+func getNextAddr(addr string) string {
+	parsed, err := netip.ParseAddr(addr)
 	if err != nil {
-		return nil, err
+		log.Fatal("failed while parsing address in getNextAddr() ", err, "\n")
 	}
-	return net.ParseIP(next.Next().String()), nil
+	return parsed.Next().String()
+
+}
+
+/*
+get the network address of the ip address in 'addr' with the subnet mask from 'cidr'
+
+	    :param addr: the ipv4 address to get the network address of
+		:param cidr: the CIDR notation of the subbet
+*/
+func getNetwork(addr string, cidr int) string {
+	addr = fmt.Sprintf("%s/%v", addr, cidr)
+	ip, net, err := net.ParseCIDR(addr)
+	if err != nil {
+		log.Fatal("failed whilst attempting to parse cidr in getNetwork() ", err, "\n")
+	}
+	return ip.Mask(net.Mask).String()
+
 }
 
 /*
 Recursive function to get all of the IPv4 addresses for each IPv4 network that the host is on
 
-	:param addr: the address to recursively find the next address for
-	:param out: a pointer to a struct containing a list of addresses
+	     :param ipmap: a pointer to an IpSubnetMapper struct which contains domain details such as
+		               the subnet mask, the original network mask, and the current IP address used in the
+					   recursive function
+		:param max: This is safety feature to prevent stack overflows, so you can manually set the depth to
+		            call the function
 */
 func addressRecurse(ipmap *IpSubnetMapper, max int) {
 
 	if len(ipmap.Ipv4s) > max {
 		return
 	}
-	next, err := getNextAddr(ipmap.NetworkAddr)
-	if err != nil {
-		log.Println(err)
-		return
-	}
-	ip, net, err := net.ParseCIDR(next.String())
-	if err != nil {
-		log.Println(err)
-		return
-	}
-	if ip.Mask(net.Mask).String() != ipmap.NetworkAddr.String() {
+
+	next := getNextAddr(ipmap.Current.String())
+
+	nextNet := getNetwork(next, ipmap.Mask)
+	currentNet := ipmap.NetworkAddr.String()
+
+	if nextNet != currentNet {
 		return
 	}
+	ipmap.Current = net.ParseIP(next)
 
-	ipmap.Ipv4s = append(ipmap.Ipv4s, next)
+	ipmap.Ipv4s = append(ipmap.Ipv4s, net.ParseIP(next))
 	addressRecurse(ipmap, max)
 }
 
@@ -75,81 +101,27 @@ func getAddressByInterface(name string) ([]net.Addr, error) {
 }
 
 /*
-Utilized a recursive function to find all addresses in the address space that the host belongs.
-Returns a pointer to an AllAddresses struct who has a list of net.IP structs inside
-*/
-func GetAllAddresses(name string, maxDepth int) (*AllAddress, error) {
-	addresses, err := getAddressByInterface(name)
-	if err != nil {
-		return nil, err
-	}
-	out := &AllAddress{}
-	for idx := range addresses {
-		ip := net.ParseIP(strings.Split(addresses[idx].String(), "/")[0])
-		root, err := netip.ParseAddr(ip.Mask(ip.DefaultMask()).String())
-		if err != nil {
-			continue
-		}
-		if root.IsLoopback() {
-			continue
-		}
-		// addressRecurse(ip, ip, out, maxDepth)
-	}
-	return out, nil
-}
+Get all of the IPv4 addresses in the network that 'addr' belongs to. YOU MUST PASS THE ADDRESS WITH CIDR NOTATION
+i.e. '192.168.50.1/24'
 
-/*
-Utilized a recursive function to find all addresses in the address space that the host belongs.
-Returns a pointer to an AllAddresses struct who has a list of net.IP structs inside
+	:param addr: the ipv4 address to use for subnet discovery
 */
-func GetAllRemoteAddresses(addrs []string, maxDepth int) (*AllAddress, error) {
-	out := &AllAddress{}
-	var addresses []net.IP
-	for i := range addrs {
-		ip, _, err := net.ParseCIDR(addrs[i])
-		if err != nil {
-			return nil, err
-		}
-		addresses = append(addresses, ip)
-	}
-
-	for idx := range addresses {
-
-		ip := net.ParseIP(strings.Split(addresses[idx].String(), "/")[0])
-		root, err := netip.ParseAddr(ip.Mask(ip.DefaultMask()).String())
-		if err != nil {
-			continue
-		}
-		if root.IsLoopback() {
-			continue
-		}
-		// addressRecurse(ip, ip, out, maxDepth)
-
-	}
-	return out, nil
-}
-
-type IpSubnetMapper struct {
-	Ipv4s       []net.IP
-	NetworkAddr net.IP
-	Mask        net.IPMask
-}
-
-func RefactorGetAllRemAddr(addr string) (*AllAddress, error) {
-	//	out := &AllAddress{}
+func GetNetworkAddresses(addr string) (*IpSubnetMapper, error) {
 	ipmap := &IpSubnetMapper{Ipv4s: []net.IP{}}
+
 	ip, net, err := net.ParseCIDR(addr)
 	if err != nil {
 		return nil, err
 	}
+	mask, err := strconv.Atoi(strings.Split(addr, "/")[1])
+	if err != nil {
+		return nil, err
+	}
 	ipmap.NetworkAddr = ip.Mask(net.Mask)
-	ipmap.Mask = ip.DefaultMask()
-	fmt.Printf("%+v\n", ip.Mask(net.Mask))
-	fmt.Println(ip.DefaultMask())
-	fmt.Printf("%s\n", net.IP.DefaultMask())
-
-	addressRecurse(ipmap, 2000)
+	ipmap.Mask = mask
+	ipmap.Current = ip.Mask(net.Mask)
+	addressRecurse(ipmap, 65535)
 
-	return nil, nil
+	return ipmap, nil
 
 }

+ 112 - 20
pkg/local_test.go

@@ -2,58 +2,150 @@ package kyoketsu
 
 import (
 	"encoding/json"
+	"fmt"
 	"log"
-	"net/netip"
+	"net"
 	"os"
 	"testing"
-
-	"github.com/google/go-cmp/cmp"
 )
 
-func LoadTestAddresses(loc string) *AllAddress {
+type IpAddresses struct {
+	Addrs []string `json:"addresses"`
+}
+
+func LoadTestAddresses(loc string) map[string]struct{} {
 	b, err := os.ReadFile(loc)
 	if err != nil {
 		log.Fatal("Test setup failed.\n", err)
 	}
-	var alladdr AllAddress
-	err = json.Unmarshal(b, &alladdr)
+	var addr IpAddresses
+	addrmap := map[string]struct{}{}
+	err = json.Unmarshal(b, &addr)
 	if err != nil {
 		log.Fatal("test setup failed.\n", err)
 	}
-	return &alladdr
+	for i := range addr.Addrs {
+		addrmap[addr.Addrs[i]] = struct{}{}
+	}
+	return addrmap
 
 }
 
 // Testing the addres recursion function to return all IPs in the target address subnet
+// All test cases use a select sort to assert that all addresses in the test data are in the return
 func TestAddressRecurse(t *testing.T) {
 	type TestCase struct {
 		Name       string
-		Wants      *AllAddress
-		Input      string
+		TestData   string
+		InputAddr  string
+		InputMask  int
 		ShouldFail bool
 	}
 
 	tc := []TestCase{
 		TestCase{
-			Name:  "Passing testcase with valid IP address, returns all addresses.",
-			Wants: LoadTestAddresses("../test/local_ips.json"),
-			Input: "192.168.50.50",
+			Name:      "Passing testcase with valid IP address, returns all addresses.",
+			TestData:  "../test/local_ips.json",
+			InputAddr: "192.168.50.50",
+			InputMask: 24,
 		},
 		TestCase{
-			Name:  "Passing testcase with valid IP address that belongs to a /16 subnet",
-			Wants: LoadTestAddresses("../test/slash16_ips.json"),
-			Input: "10.252.1.0",
+			Name:      "Passing testcase with valid IP address that belongs to a /16 subnet",
+			TestData:  "../test/slash16_ips.json",
+			InputAddr: "10.252.1.1",
+			InputMask: 16,
 		},
 	}
 	for i := range tc {
-		addr, err := netip.ParseAddr(tc[i].Input)
+		addr, network, err := net.ParseCIDR(fmt.Sprintf("%s/%v", tc[i].InputAddr, tc[i].InputMask))
 		if err != nil {
 			t.Errorf("Test case: '%s' failed! Reason: %s", tc[i].Name, err)
 		}
-		got := &AllAddress{}
-		addressRecurse(addr, got, 65535)
-		if !cmp.Equal(got, tc[i].Wants) {
-			t.Errorf("Test case: '%s' failed! Got: %+v\nWant: %+v\n", tc[i].Name, got, tc[i].Wants)
+		got := &IpSubnetMapper{}
+		got.Mask = tc[i].InputMask
+		got.NetworkAddr = addr.Mask(network.Mask)
+		got.Current = addr.Mask(network.Mask)
+		addressRecurse(got, 65535)
+		want := LoadTestAddresses(tc[i].TestData)
+		for x := range got.Ipv4s {
+			gotip := got.Ipv4s[x]
+			_, ok := want[gotip.String()]
+			if !ok {
+				t.Errorf("Test '%s' failed! Address: %s was not found in the test data: %s\n", tc[i].Name, gotip.String(), tc[i].TestData)
+			}
+
+		}
+		log.Printf("Nice! Test: '%s' passed!\n", tc[i].Name)
+	}
+
+}
+
+// Testing the function to retrieve the next network address
+func TestGetNextAddr(t *testing.T) {
+	type TestCase struct {
+		Name       string
+		Input      string
+		Wants      string
+		ShouldFail bool
+	}
+
+	tc := []TestCase{
+		TestCase{
+			Name:       "Passing test case, function returns the next address",
+			Input:      "10.252.1.1",
+			Wants:      "10.252.1.2",
+			ShouldFail: false,
+		},
+		TestCase{
+			Name:       "Failing test case, function returns the wrong address",
+			Input:      "10.252.1.1",
+			Wants:      "10.252.1.4",
+			ShouldFail: true,
+		},
+	}
+	for i := range tc {
+		got := getNextAddr(tc[i].Input)
+		if got != tc[i].Wants {
+			if !tc[i].ShouldFail {
+				t.Errorf("Test: '%s' failed! Return: %s\nTest expected: %s\nTest Should fail: %v\n", tc[i].Name, got, tc[i].Wants, tc[i].ShouldFail)
+			}
+
+		}
+	}
+
+}
+
+func TestGetNetwork(t *testing.T) {
+	type TestCase struct {
+		Name       string
+		InputAddr  string
+		InputMask  int
+		Expects    string
+		ShouldFail bool
+	}
+	tc := []TestCase{
+		TestCase{
+			Name:       "Passing test, function returns the correct network given the CIDR mask",
+			InputAddr:  "192.168.50.35",
+			InputMask:  24,
+			Expects:    "192.168.50.0",
+			ShouldFail: false,
+		},
+		TestCase{
+			Name:       "Passing test, function returns the correct network given the CIDR mask (Larger network, /16 CIDR)",
+			InputAddr:  "10.252.47.200",
+			InputMask:  16,
+			Expects:    "10.252.0.0",
+			ShouldFail: false,
+		},
+	}
+	for i := range tc {
+		got := getNetwork(tc[i].InputAddr, tc[i].InputMask)
+		if got != tc[i].Expects {
+			if !tc[i].ShouldFail {
+				t.Errorf("Test: '%s' failed! Returned: %s\nExpected: %s\nShould fail: %v", tc[i].Name, got, tc[i].Expects, tc[i].ShouldFail)
+			}
+
 		}
 
 	}

File diff suppressed because it is too large
+ 0 - 1
test/slash16_ips.json


Some files were not shown because too many files changed in this diff