소스 검색

fixed memory leak! as well as improved the overall memory usage of the program. lookin good!

AETH-erial 10 달 전
부모
커밋
9847fad52c
6개의 변경된 파일72개의 추가작업 그리고 58개의 파일을 삭제
  1. 0 2
      cmd/kyoketsu-web/kyoketsu-web.go
  2. 11 22
      cmd/kyoketsu/kyoketsu.go
  3. 1 1
      pkg/html/templates/ip_table.html
  4. 19 11
      pkg/scanner.go
  5. 5 5
      pkg/storage.go
  6. 36 17
      pkg/webserver.go

+ 0 - 2
cmd/kyoketsu-web/kyoketsu-web.go

@@ -32,7 +32,6 @@ import (
 	"flag"
 	"log"
 	"net/http"
-	"os"
 	"sync"
 
 	kyoketsu "git.aetherial.dev/aeth/kyoketsu/pkg"
@@ -46,7 +45,6 @@ func main() {
 	debug := flag.Bool("debug", false, "Pass this to start the pprof server")
 	flag.Parse()
 
-	os.Remove(dbfile) // TODO: remove this once i add more smart interaction with the DB
 	db, err := sql.Open("sqlite3", dbfile)
 	if err != nil {
 		log.Fatal(err)

+ 11 - 22
cmd/kyoketsu/kyoketsu.go

@@ -65,7 +65,7 @@ func main() {
 		log.Fatal(err)
 	}
 	prompt := promptui.Select{
-		Label: "Select the network you wish to scan.",
+		Label: "Select the network you wish to scan",
 		Items: localAddr.Choice,
 	}
 	choice, _, err := prompt.Run()
@@ -78,26 +78,15 @@ func main() {
 	if err != nil {
 		log.Fatal(err)
 	}
-	out := kyoketsu.NetSweep(addr.Ipv4s, kyoketsu.RetrieveScanDirectives())
-	for x := range out {
-		fmt.Printf("%+v\n", x)
-	}
-	/*
-		for i := range addr.Ipv4s {
-			wg.Add(1)
-			go func(target string, wg *sync.WaitGroup) {
-				defer wg.Done()
-				out := kyoketsu.PortWalk(target, kyoketsu.RetrieveScanDirectives())
-				if len(out) > 0 {
-
-					fmt.Print(" |-|-|-| :::: HOST FOUND :::: |-|-|-|\n==================||==================\n")
-					fmt.Printf("IPv4 Address: %s\nListening Ports: %v\n=====================================\n", target, out)
-
-				}
-
-			}(addr.Ipv4s[i].String(), wg)
-
+	scanned := make(chan kyoketsu.Host)
+	go func() {
+		for x := range scanned {
+			if len(x.ListeningPorts) > 0 {
+				fmt.Print(" |-|-|-| :::: HOST FOUND :::: |-|-|-|\n==================||==================\n")
+				fmt.Printf("IPv4 Address: %s\nFully Qualified Domain Name: %s\nListening Ports: %v\n=====================================\n", x.IpAddress, x.Fqdn, x.ListeningPorts)
+			}
 		}
-		wg.Wait()
-	*/
+	}()
+	kyoketsu.NetSweep(addr.Ipv4s, kyoketsu.RetrieveScanDirectives(), scanned)
+
 }

+ 1 - 1
pkg/html/templates/ip_table.html

@@ -3,6 +3,6 @@
     <div class="col border border-white text-white"><p class="font-monospace fs-3">{{ .Fqdn }}</p></div>
     <div class="col border border-white text-white"><p class="font-monospace fs-3">{{ .IpAddress }}</p></div>
     <div class="col border border-white text-white"><p class="font-monospace fs-3">{{ .PingResponse }}</p></div>
-    <div class="col border border-white text-white"><p class="font-monospace fs-3">{{ .PortString }}</p></div>
+    <div class="col border border-white text-white"><p class="font-monospace fs-3">{{ .ListeningPorts }}</p></div>
 </div>
 {{ end }}

+ 19 - 11
pkg/scanner.go

@@ -28,6 +28,8 @@ import (
 	"fmt"
 	"log"
 	"net"
+	"strings"
+	"sync"
 	"syscall"
 	"time"
 )
@@ -112,26 +114,32 @@ Perform a port scan sweep across an entire subnet
 	:param ip: the IPv4 address WITH CIDR notation
 	:param portmap: the mapping of ports to scan with (port number mapped to protocol name)
 */
-func NetSweep(ips []net.IP, ports []int) []Host {
-	scanned := make(chan Host)
+func NetSweep(ips []net.IP, ports []int, scanned chan Host) {
+
+	wg := &sync.WaitGroup{}
 	for i := range ips {
-		go func(target string, portnum []int) {
+		wg.Add(1)
+		go func(target string, portnum []int, wgrp *sync.WaitGroup) {
+			defer wgrp.Done()
 			scanned <- Host{
+				Fqdn:           getFqdn(target),
 				IpAddress:      target,
 				ListeningPorts: PortWalk(target, portnum),
 			}
 
-		}(ips[i].String(), ports)
-	}
-	var hosts []Host
-	for x := range scanned {
-		fmt.Printf("%+v\n", x)
-
-		hosts = append(hosts, x)
+		}(ips[i].String(), ports, wg)
 	}
+	wg.Wait()
+	close(scanned)
 
-	return hosts
+}
 
+func getFqdn(ip string) string {
+	names, err := net.LookupAddr(ip)
+	if err != nil {
+		return "not found with default resolver"
+	}
+	return strings.Join(names, ", ")
 }
 
 /*

+ 5 - 5
pkg/storage.go

@@ -74,7 +74,7 @@ func (r *SQLiteRepo) Migrate() error {
         id INTEGER PRIMARY KEY AUTOINCREMENT,
         fqdn TEXT NOT NULL,
         ipv4_address TEXT NOT NULL UNIQUE,
-        listening_port TEXT NOT NULL
+        listening_port INTEGER[] NOT NULL
     );
     `
 
@@ -88,7 +88,7 @@ Create an entry in the hosts table
 	:param host: a Host entry from a port scan
 */
 func (r *SQLiteRepo) Create(host Host) (*Host, error) {
-	res, err := r.db.Exec("INSERT INTO hosts(fqdn, ipv4_address, listening_port) values(?,?,?)", host.Fqdn, host.IpAddress, host.PortString)
+	res, err := r.db.Exec("INSERT INTO hosts(fqdn, ipv4_address, listening_port) values(?,?,?)", host.Fqdn, host.IpAddress, host.ListeningPorts)
 	if err != nil {
 		var sqliteErr sqlite3.Error
 		if errors.As(err, &sqliteErr) {
@@ -119,7 +119,7 @@ func (r *SQLiteRepo) All() ([]Host, error) {
 	var all []Host
 	for rows.Next() {
 		var host Host
-		if err := rows.Scan(&host.Id, &host.Fqdn, &host.IpAddress, &host.PortString); err != nil {
+		if err := rows.Scan(&host.Id, &host.Fqdn, &host.IpAddress, &host.ListeningPorts); err != nil {
 			return nil, err
 		}
 		all = append(all, host)
@@ -132,7 +132,7 @@ func (r *SQLiteRepo) GetByIP(ip string) (*Host, error) {
 	row := r.db.QueryRow("SELECT * FROM hosts WHERE ipv4_address = ?", ip)
 
 	var host Host
-	if err := row.Scan(&host.Id, &host.Fqdn, &host.IpAddress, &host.PortString); err != nil {
+	if err := row.Scan(&host.Id, &host.Fqdn, &host.IpAddress, &host.ListeningPorts); err != nil {
 		if errors.Is(err, sql.ErrNoRows) {
 			return nil, ErrNotExists
 		}
@@ -146,7 +146,7 @@ func (r *SQLiteRepo) Update(id int64, updated Host) (*Host, error) {
 	if id == 0 {
 		return nil, errors.New("invalid updated ID")
 	}
-	res, err := r.db.Exec("UPDATE hosts SET fqdn = ?, ipv4_address = ?, listening_port = ? WHERE id = ?", updated.Fqdn, updated.IpAddress, updated.PortString, id)
+	res, err := r.db.Exec("UPDATE hosts SET fqdn = ?, ipv4_address = ?, listening_port = ? WHERE id = ?", updated.Fqdn, updated.IpAddress, updated.ListeningPorts, id)
 	if err != nil {
 		return nil, err
 	}

+ 36 - 17
pkg/webserver.go

@@ -70,34 +70,53 @@ func (e *ExecutionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		fmt.Fprintf(w, "There was an error processing your request: %s", err)
 	}
-	scanned := NetSweep(subnetMap.Ipv4s, RetrieveScanDirectives())
-	for i := range scanned {
-		if len(scanned[i].ListeningPorts) > 0 {
-			e.TableEntry.Execute(w, scanned[i])
-			/*
-				rec, err := e.DbHook.GetByIP(host.IpAddress)
+	scanned := make(chan Host)
+	go func() {
+		for x := range scanned {
+			if len(x.ListeningPorts) > 0 {
+				e.TableEntry.Execute(w, x)
+
+				fmt.Print(" |-|-|-| :::: HOST FOUND :::: |-|-|-|\n==================||==================\n")
+				fmt.Printf("IPv4 Address: %s\nListening Ports: %v\n=====================================\n", x.IpAddress, x.ListeningPorts)
+				host, err := e.DbHook.GetByIP(x.IpAddress)
 				if err != nil {
 					if err != ErrNotExists {
-						log.Fatal(err, "There was a fatal error querying the database\n")
+						log.Fatal(err, " Couldnt access the database. Fatal error.\n")
 					}
-					_, err = e.DbHook.Create(host)
+					_, err = e.DbHook.Create(x)
 					if err != nil {
-						log.Fatal(err, "There was a fatal error creating this record in the database.\n")
+						log.Fatal(err, " Fatal error trying to read the database.\n")
 					}
 					continue
 				}
-				_, err = e.DbHook.Update(rec.Id, host)
+				_, err = e.DbHook.Update(host.Id, x)
 				if err != nil {
-					log.Fatal(err, "there was a fatal error updating this record: ", rec)
+					log.Fatal(err, " fatal error when updating a record.\n")
 				}
-				continue
-			*/
+
+			}
+		}
+	}()
+	NetSweep(subnetMap.Ipv4s, RetrieveScanDirectives(), scanned)
+	/*
+		rec, err := e.DbHook.GetByIP(host.IpAddress)
+		if err != nil {
+			if err != ErrNotExists {
+				log.Fatal(err, "There was a fatal error querying the database\n")
+			}
+			_, err = e.DbHook.Create(host)
+			if err != nil {
+				log.Fatal(err, "There was a fatal error creating this record in the database.\n")
+			}
 			continue
-			// not adding these to the database because this isnt really that important
 		}
-
-	}
-
+		_, err = e.DbHook.Update(rec.Id, host)
+		if err != nil {
+			log.Fatal(err, "there was a fatal error updating this record: ", rec)
+		}
+		continue
+	*/
+	// not adding these to the database because this isnt really that important
 }
 
 // handlers //