Sfoglia il codice sorgente

got the recursive func working and added some test cases

AETH-erial 10 mesi fa
parent
commit
d7bc9d87b4
3 ha cambiato i file con 173 aggiunte e 110 eliminazioni
  1. 61 89
      pkg/local.go
  2. 112 20
      pkg/local_test.go
  3. 0 1
      test/slash16_ips.json

+ 61 - 89
pkg/local.go

@@ -5,13 +5,10 @@ import (
 	"log"
 	"log"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
+	"strconv"
 	"strings"
 	"strings"
 )
 )
 
 
-type AllAddress struct {
-	Addr []net.IP `json:"addresses"`
-}
-
 type NetworkInterfaceNotFound struct{ Passed string }
 type NetworkInterfaceNotFound struct{ Passed string }
 
 
 // Implementing error interface
 // Implementing error interface
@@ -19,40 +16,69 @@ func (n *NetworkInterfaceNotFound) Error() string {
 	return fmt.Sprintf("Interface: '%s' not found.", n.Passed)
 	return fmt.Sprintf("Interface: '%s' not found.", n.Passed)
 }
 }
 
 
-func getNextAddr(addr net.IP) (net.IP, error) {
-	next, err := netip.ParseAddr(addr.String())
+type IpSubnetMapper struct {
+	Ipv4s       []net.IP `json:"addresses"`
+	NetworkAddr net.IP
+	Current     net.IP
+	Mask        int
+}
+
+/*
+Get the next IPv4 address of the address specified in the 'addr' argument,
+
+	:param addr: the address to get the next address of
+*/
+func getNextAddr(addr string) string {
+	parsed, err := netip.ParseAddr(addr)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		log.Fatal("failed while parsing address in getNextAddr() ", err, "\n")
 	}
 	}
-	return net.ParseIP(next.Next().String()), nil
+	return parsed.Next().String()
+
+}
+
+/*
+get the network address of the ip address in 'addr' with the subnet mask from 'cidr'
+
+	    :param addr: the ipv4 address to get the network address of
+		:param cidr: the CIDR notation of the subbet
+*/
+func getNetwork(addr string, cidr int) string {
+	addr = fmt.Sprintf("%s/%v", addr, cidr)
+	ip, net, err := net.ParseCIDR(addr)
+	if err != nil {
+		log.Fatal("failed whilst attempting to parse cidr in getNetwork() ", err, "\n")
+	}
+	return ip.Mask(net.Mask).String()
+
 }
 }
 
 
 /*
 /*
 Recursive function to get all of the IPv4 addresses for each IPv4 network that the host is on
 Recursive function to get all of the IPv4 addresses for each IPv4 network that the host is on
 
 
-	:param addr: the address to recursively find the next address for
-	:param out: a pointer to a struct containing a list of addresses
+	     :param ipmap: a pointer to an IpSubnetMapper struct which contains domain details such as
+		               the subnet mask, the original network mask, and the current IP address used in the
+					   recursive function
+		:param max: This is safety feature to prevent stack overflows, so you can manually set the depth to
+		            call the function
 */
 */
 func addressRecurse(ipmap *IpSubnetMapper, max int) {
 func addressRecurse(ipmap *IpSubnetMapper, max int) {
 
 
 	if len(ipmap.Ipv4s) > max {
 	if len(ipmap.Ipv4s) > max {
 		return
 		return
 	}
 	}
-	next, err := getNextAddr(ipmap.NetworkAddr)
-	if err != nil {
-		log.Println(err)
-		return
-	}
-	ip, net, err := net.ParseCIDR(next.String())
-	if err != nil {
-		log.Println(err)
-		return
-	}
-	if ip.Mask(net.Mask).String() != ipmap.NetworkAddr.String() {
+
+	next := getNextAddr(ipmap.Current.String())
+
+	nextNet := getNetwork(next, ipmap.Mask)
+	currentNet := ipmap.NetworkAddr.String()
+
+	if nextNet != currentNet {
 		return
 		return
 	}
 	}
+	ipmap.Current = net.ParseIP(next)
 
 
-	ipmap.Ipv4s = append(ipmap.Ipv4s, next)
+	ipmap.Ipv4s = append(ipmap.Ipv4s, net.ParseIP(next))
 	addressRecurse(ipmap, max)
 	addressRecurse(ipmap, max)
 }
 }
 
 
@@ -75,81 +101,27 @@ func getAddressByInterface(name string) ([]net.Addr, error) {
 }
 }
 
 
 /*
 /*
-Utilized a recursive function to find all addresses in the address space that the host belongs.
-Returns a pointer to an AllAddresses struct who has a list of net.IP structs inside
-*/
-func GetAllAddresses(name string, maxDepth int) (*AllAddress, error) {
-	addresses, err := getAddressByInterface(name)
-	if err != nil {
-		return nil, err
-	}
-	out := &AllAddress{}
-	for idx := range addresses {
-		ip := net.ParseIP(strings.Split(addresses[idx].String(), "/")[0])
-		root, err := netip.ParseAddr(ip.Mask(ip.DefaultMask()).String())
-		if err != nil {
-			continue
-		}
-		if root.IsLoopback() {
-			continue
-		}
-		// addressRecurse(ip, ip, out, maxDepth)
-	}
-	return out, nil
-}
+Get all of the IPv4 addresses in the network that 'addr' belongs to. YOU MUST PASS THE ADDRESS WITH CIDR NOTATION
+i.e. '192.168.50.1/24'
 
 
-/*
-Utilized a recursive function to find all addresses in the address space that the host belongs.
-Returns a pointer to an AllAddresses struct who has a list of net.IP structs inside
+	:param addr: the ipv4 address to use for subnet discovery
 */
 */
-func GetAllRemoteAddresses(addrs []string, maxDepth int) (*AllAddress, error) {
-	out := &AllAddress{}
-	var addresses []net.IP
-	for i := range addrs {
-		ip, _, err := net.ParseCIDR(addrs[i])
-		if err != nil {
-			return nil, err
-		}
-		addresses = append(addresses, ip)
-	}
-
-	for idx := range addresses {
-
-		ip := net.ParseIP(strings.Split(addresses[idx].String(), "/")[0])
-		root, err := netip.ParseAddr(ip.Mask(ip.DefaultMask()).String())
-		if err != nil {
-			continue
-		}
-		if root.IsLoopback() {
-			continue
-		}
-		// addressRecurse(ip, ip, out, maxDepth)
-
-	}
-	return out, nil
-}
-
-type IpSubnetMapper struct {
-	Ipv4s       []net.IP
-	NetworkAddr net.IP
-	Mask        net.IPMask
-}
-
-func RefactorGetAllRemAddr(addr string) (*AllAddress, error) {
-	//	out := &AllAddress{}
+func GetNetworkAddresses(addr string) (*IpSubnetMapper, error) {
 	ipmap := &IpSubnetMapper{Ipv4s: []net.IP{}}
 	ipmap := &IpSubnetMapper{Ipv4s: []net.IP{}}
+
 	ip, net, err := net.ParseCIDR(addr)
 	ip, net, err := net.ParseCIDR(addr)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+	mask, err := strconv.Atoi(strings.Split(addr, "/")[1])
+	if err != nil {
+		return nil, err
+	}
 	ipmap.NetworkAddr = ip.Mask(net.Mask)
 	ipmap.NetworkAddr = ip.Mask(net.Mask)
-	ipmap.Mask = ip.DefaultMask()
-	fmt.Printf("%+v\n", ip.Mask(net.Mask))
-	fmt.Println(ip.DefaultMask())
-	fmt.Printf("%s\n", net.IP.DefaultMask())
-
-	addressRecurse(ipmap, 2000)
+	ipmap.Mask = mask
+	ipmap.Current = ip.Mask(net.Mask)
+	addressRecurse(ipmap, 65535)
 
 
-	return nil, nil
+	return ipmap, nil
 
 
 }
 }

+ 112 - 20
pkg/local_test.go

@@ -2,58 +2,150 @@ package kyoketsu
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"fmt"
 	"log"
 	"log"
-	"net/netip"
+	"net"
 	"os"
 	"os"
 	"testing"
 	"testing"
-
-	"github.com/google/go-cmp/cmp"
 )
 )
 
 
-func LoadTestAddresses(loc string) *AllAddress {
+type IpAddresses struct {
+	Addrs []string `json:"addresses"`
+}
+
+func LoadTestAddresses(loc string) map[string]struct{} {
 	b, err := os.ReadFile(loc)
 	b, err := os.ReadFile(loc)
 	if err != nil {
 	if err != nil {
 		log.Fatal("Test setup failed.\n", err)
 		log.Fatal("Test setup failed.\n", err)
 	}
 	}
-	var alladdr AllAddress
-	err = json.Unmarshal(b, &alladdr)
+	var addr IpAddresses
+	addrmap := map[string]struct{}{}
+	err = json.Unmarshal(b, &addr)
 	if err != nil {
 	if err != nil {
 		log.Fatal("test setup failed.\n", err)
 		log.Fatal("test setup failed.\n", err)
 	}
 	}
-	return &alladdr
+	for i := range addr.Addrs {
+		addrmap[addr.Addrs[i]] = struct{}{}
+	}
+	return addrmap
 
 
 }
 }
 
 
 // Testing the addres recursion function to return all IPs in the target address subnet
 // Testing the addres recursion function to return all IPs in the target address subnet
+// All test cases use a select sort to assert that all addresses in the test data are in the return
 func TestAddressRecurse(t *testing.T) {
 func TestAddressRecurse(t *testing.T) {
 	type TestCase struct {
 	type TestCase struct {
 		Name       string
 		Name       string
-		Wants      *AllAddress
-		Input      string
+		TestData   string
+		InputAddr  string
+		InputMask  int
 		ShouldFail bool
 		ShouldFail bool
 	}
 	}
 
 
 	tc := []TestCase{
 	tc := []TestCase{
 		TestCase{
 		TestCase{
-			Name:  "Passing testcase with valid IP address, returns all addresses.",
-			Wants: LoadTestAddresses("../test/local_ips.json"),
-			Input: "192.168.50.50",
+			Name:      "Passing testcase with valid IP address, returns all addresses.",
+			TestData:  "../test/local_ips.json",
+			InputAddr: "192.168.50.50",
+			InputMask: 24,
 		},
 		},
 		TestCase{
 		TestCase{
-			Name:  "Passing testcase with valid IP address that belongs to a /16 subnet",
-			Wants: LoadTestAddresses("../test/slash16_ips.json"),
-			Input: "10.252.1.0",
+			Name:      "Passing testcase with valid IP address that belongs to a /16 subnet",
+			TestData:  "../test/slash16_ips.json",
+			InputAddr: "10.252.1.1",
+			InputMask: 16,
 		},
 		},
 	}
 	}
 	for i := range tc {
 	for i := range tc {
-		addr, err := netip.ParseAddr(tc[i].Input)
+		addr, network, err := net.ParseCIDR(fmt.Sprintf("%s/%v", tc[i].InputAddr, tc[i].InputMask))
 		if err != nil {
 		if err != nil {
 			t.Errorf("Test case: '%s' failed! Reason: %s", tc[i].Name, err)
 			t.Errorf("Test case: '%s' failed! Reason: %s", tc[i].Name, err)
 		}
 		}
-		got := &AllAddress{}
-		addressRecurse(addr, got, 65535)
-		if !cmp.Equal(got, tc[i].Wants) {
-			t.Errorf("Test case: '%s' failed! Got: %+v\nWant: %+v\n", tc[i].Name, got, tc[i].Wants)
+		got := &IpSubnetMapper{}
+		got.Mask = tc[i].InputMask
+		got.NetworkAddr = addr.Mask(network.Mask)
+		got.Current = addr.Mask(network.Mask)
+		addressRecurse(got, 65535)
+		want := LoadTestAddresses(tc[i].TestData)
+		for x := range got.Ipv4s {
+			gotip := got.Ipv4s[x]
+			_, ok := want[gotip.String()]
+			if !ok {
+				t.Errorf("Test '%s' failed! Address: %s was not found in the test data: %s\n", tc[i].Name, gotip.String(), tc[i].TestData)
+			}
+
+		}
+		log.Printf("Nice! Test: '%s' passed!\n", tc[i].Name)
+	}
+
+}
+
+// Testing the function to retrieve the next network address
+func TestGetNextAddr(t *testing.T) {
+	type TestCase struct {
+		Name       string
+		Input      string
+		Wants      string
+		ShouldFail bool
+	}
+
+	tc := []TestCase{
+		TestCase{
+			Name:       "Passing test case, function returns the next address",
+			Input:      "10.252.1.1",
+			Wants:      "10.252.1.2",
+			ShouldFail: false,
+		},
+		TestCase{
+			Name:       "Failing test case, function returns the wrong address",
+			Input:      "10.252.1.1",
+			Wants:      "10.252.1.4",
+			ShouldFail: true,
+		},
+	}
+	for i := range tc {
+		got := getNextAddr(tc[i].Input)
+		if got != tc[i].Wants {
+			if !tc[i].ShouldFail {
+				t.Errorf("Test: '%s' failed! Return: %s\nTest expected: %s\nTest Should fail: %v\n", tc[i].Name, got, tc[i].Wants, tc[i].ShouldFail)
+			}
+
+		}
+	}
+
+}
+
+func TestGetNetwork(t *testing.T) {
+	type TestCase struct {
+		Name       string
+		InputAddr  string
+		InputMask  int
+		Expects    string
+		ShouldFail bool
+	}
+	tc := []TestCase{
+		TestCase{
+			Name:       "Passing test, function returns the correct network given the CIDR mask",
+			InputAddr:  "192.168.50.35",
+			InputMask:  24,
+			Expects:    "192.168.50.0",
+			ShouldFail: false,
+		},
+		TestCase{
+			Name:       "Passing test, function returns the correct network given the CIDR mask (Larger network, /16 CIDR)",
+			InputAddr:  "10.252.47.200",
+			InputMask:  16,
+			Expects:    "10.252.0.0",
+			ShouldFail: false,
+		},
+	}
+	for i := range tc {
+		got := getNetwork(tc[i].InputAddr, tc[i].InputMask)
+		if got != tc[i].Expects {
+			if !tc[i].ShouldFail {
+				t.Errorf("Test: '%s' failed! Returned: %s\nExpected: %s\nShould fail: %v", tc[i].Name, got, tc[i].Expects, tc[i].ShouldFail)
+			}
+
 		}
 		}
 
 
 	}
 	}

File diff suppressed because it is too large
+ 0 - 1
test/slash16_ips.json


Some files were not shown because too many files changed in this diff