Răsfoiți Sursa

concurrency cleanup

aeth 7 luni în urmă
părinte
comite
499d8f566c
4 a modificat fișierele cu 73 adăugiri și 43 ștergeri
  1. 22 15
      pkg/html/templates/home.html
  2. 8 4
      pkg/scanner.go
  3. 7 7
      pkg/storage.go
  4. 36 17
      pkg/webserver.go

+ 22 - 15
pkg/html/templates/home.html

@@ -24,30 +24,37 @@
                               <button type="submit">scan</button>
                         </form>
                     </div>
-                    <div class="col-2">
-                        <a style="color: white; font-family: monospace;">Filter on network address:</a>
-                        <form method="post"
-                              hx-target="#response-div"
-                              hx-post="/subnets"
-                              hx-ext="json-enc">
-                              <input type="text" name="ip_address" placeholder="192.168.50.0" required>
-                              <button type="submit">filter</button>
-                        </form>
-                    </div>
-                    <div class="col-2">
-                        <a style="color: white; font-family: monospace;">Exclude FQDN pattern:</a>
+                        <div class="col-4">
                         <form method="post"
                               hx-target="#response-div"
                               hx-post="/excludefqdn"
                               hx-ext="json-enc">
-                              <input type="text" name="fqdn_pattern" placeholder="wasted-domain.xyz" required>
-                              <button type="submit">filter</button>
+                            <div class="row container-fluid p-1">
+                                <div class="col-auto">
+                                    <div class="row">
+                                    <a style="color: white; font-family: monospace;">Exclude FQDN pattern:</a>
+                                </div>
+                                <div class="row">
+                                    <input type="text" name="fqdn_pattern" placeholder="wasted-domain.xyz">
+                                </div>
+                                </div>
+                                <div class="col-auto">
+                                    <div class="row">
+                                        <a style="color: white; font-family: monospace;">Filter on network address:</a>
+                                    </div>
+                                    <div class="row">
+                                    <input type="text" name="network_address" placeholder="192.168.50.0">
+                                    </div>
+                                    <div class="col-2">
+                                    <button type="submit">filter</button>
+                                </div>
+                            </div>
                         </form>
+                        </div>
                     </div>
                 </div>
             </div>
         </div>
-
         <table class="table table-dark table-bordered table-hover table-">
             <thead>
                 <tr>

+ 8 - 4
pkg/scanner.go

@@ -128,7 +128,11 @@ func singlePortScan(addr string, port int) int {
 	if err != nil {
 		return 0
 	}
-	conn.Close()
+	defer func() {
+		if conn != nil {
+			conn.Close()
+		}
+	}()
 	return port
 }
 
@@ -143,8 +147,8 @@ func NetSweep(ips []net.IP, cidr int, ports []int, scanned chan Host) {
 	network := getNetwork(ips[0].String(), cidr)
 	for i := range ips {
 		wg.Add(1)
-		go func(target string, ntwrk string, portnum []int, wgrp *sync.WaitGroup, output chan Host) {
-			defer wgrp.Done()
+		go func(target string, ntwrk string, portnum []int, output chan Host) {
+			defer wg.Done()
 			portscanned := PortWalk(target, portnum)
 			output <- Host{
 				Fqdn:           getFqdn(target),
@@ -153,7 +157,7 @@ func NetSweep(ips []net.IP, cidr int, ports []int, scanned chan Host) {
 				PortString:     strings.Trim(strings.Join(strings.Fields(fmt.Sprint(portscanned)), ","), "[]"),
 				Network:        ntwrk,
 			}
-		}(ips[i].String(), network, ports, wg, scanned)
+		}(ips[i].String(), network, ports, scanned)
 
 	}
 	wg.Wait()

+ 7 - 7
pkg/storage.go

@@ -45,7 +45,7 @@ type TopologyDatabaseIO interface {
 	Create(host Host) (*Host, error)
 	All() ([]Host, error)
 	GetByNetwork(network string) ([]Host, error)
-	FilterDnsPattern(patterns []string) ([]Host, error)
+	FilterDnsPattern(network string, patterns []string) ([]Host, error)
 	GetByIP(ip string) (*Host, error)
 	Update(id int64, updated Host) (*Host, error)
 	Delete(id int64) error
@@ -204,17 +204,18 @@ func (r *SQLiteRepo) GetByNetwork(network string) ([]Host, error) {
 	return hosts, nil
 }
 
-func (r *SQLiteRepo) FilterDnsPattern(patterns []string) ([]Host, error) {
+func (r *SQLiteRepo) FilterDnsPattern(network string, patterns []string) ([]Host, error) {
 	var queryBuilder strings.Builder
-	queryBuilder.WriteString("SELECT * FROM hosts WHERE ")
-	args := make([]interface{}, len(patterns))
+	queryBuilder.WriteString("SELECT * FROM hosts WHERE network LIKE ? AND")
+	args := make([]interface{}, len(patterns)+1)
+	args[0] = network
 
 	for i, pattern := range patterns {
 		if i > 0 {
 			queryBuilder.WriteString(" AND ")
 		}
-		queryBuilder.WriteString("fqdn NOT LIKE ?")
-		args[i] = "%" + pattern + "%"
+		queryBuilder.WriteString(" fqdn NOT LIKE ?")
+		args[i+1] = "%" + pattern + "%"
 	}
 
 	rows, err := r.db.Query(queryBuilder.String(), args...)
@@ -223,7 +224,6 @@ func (r *SQLiteRepo) FilterDnsPattern(patterns []string) ([]Host, error) {
 	}
 	var hosts []Host
 	defer rows.Close()
-
 	for rows.Next() {
 		var host Host
 		if err := rows.Scan(&host.Id, &host.Fqdn, &host.IpAddress, &host.PortString, &host.Network); err != nil {

+ 36 - 17
pkg/webserver.go

@@ -14,8 +14,9 @@ import (
 )
 
 type ScanRequest struct {
-	IpAddress   string `json:"ip_address"`
-	FqdnPattern string `json:"fqdn_pattern"`
+	IpAddress      string `json:"ip_address"`
+	NetworkAddress string `json:"network_address"`
+	FqdnPattern    string `json:"fqdn_pattern"`
 }
 
 // Holding all static web server resources
@@ -38,7 +39,7 @@ func RunHttpServer(port int, dbhook TopologyDatabaseIO, portmap []int, logStream
 	if err != nil {
 		log.Fatal(err)
 	}
-	htmlHndl := &HtmlHandler{Home: tmpl, TableEntry: iptable, DbHook: dbhook}
+	htmlHndl := &HtmlHandler{Home: tmpl, TableEntry: iptable, DbHook: dbhook, stream: logStream}
 	execHndl := &ExecutionHandler{DbHook: dbhook, PortMap: portmap, TableEntry: iptable, stream: logStream}
 	http.Handle("/static/", assets)
 	http.Handle("/home", htmlHndl)
@@ -58,7 +59,7 @@ type ExecutionHandler struct {
 }
 
 func (e *ExecutionHandler) Log(vals ...string) {
-	e.stream.Write([]byte("KYOKETSU-WEB LOG ||| " + strings.Join(vals, " ||| ")))
+	e.stream.Write([]byte("KYOKETSU-WEB LOG ||| " + strings.Join(vals, " ||| ") + "\n"))
 
 }
 
@@ -69,11 +70,13 @@ Top level function to be routed to, this will spawn a suite of goroutines that w
 	:param r: a pointer to the request coming in from the client
 */
 func (e *ExecutionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	e.Log("Recieved: " + r.Method + " on path: " + r.RequestURI)
 	input, err := e.parseRequest(r)
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
+	e.Log(fmt.Sprintf("Parsed request struct: %+v", input))
 
 	subnetMap, err := GetNetworkAddresses(input.IpAddress)
 	if err != nil {
@@ -81,17 +84,19 @@ func (e *ExecutionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	scanned := make(chan Host)
+	scanned := make(chan Host, 1000)
 	var wg sync.WaitGroup
 	var mu sync.Mutex
 	var errorRaised bool
 
 	wg.Add(1)
+
 	go e.processScannedData(w, e.TableEntry, scanned, &wg, &mu, &errorRaised)
 
 	NetSweep(subnetMap.Ipv4s, subnetMap.Mask, RetrieveScanDirectives(), scanned)
-	close(scanned)
+	e.Log("Waiting for execution group to return finish.")
 	wg.Wait()
+	e.Log("Execution group finished scanning.")
 
 	if errorRaised {
 		http.Error(w, "Error during scan processing. Check logs for details.", http.StatusInternalServerError)
@@ -170,17 +175,22 @@ type HtmlHandler struct {
 	Home       *template.Template // pointer to the HTML homepage
 	TableEntry *template.Template // pointer to the table entry html template
 	DbHook     TopologyDatabaseIO
+	stream     io.Writer
 }
 
-func (h *HtmlHandler) handleHome(w http.ResponseWriter, r *http.Request) {
-	if r.RequestURI == "/home" {
-		data, err := h.DbHook.All()
-		if err != nil {
-			http.Error(w, "There was an error reading from the database: "+err.Error(), http.StatusInternalServerError)
-		}
-		h.Home.Execute(w, data)
-		return
+func (h *HtmlHandler) Log(vals ...string) {
+	h.stream.Write([]byte("KYOKETSU-WEB LOG ||| " + strings.Join(vals, " ||| ") + "\n"))
+
+}
+
+func (h *HtmlHandler) handleHome(w http.ResponseWriter) {
+	data, err := h.DbHook.All()
+	if err != nil {
+		h.Log("Error reading from database: " + err.Error())
+		http.Error(w, "There was an error reading from the database: "+err.Error(), http.StatusInternalServerError)
 	}
+	h.Home.Execute(w, data)
+	return
 
 }
 
@@ -197,7 +207,7 @@ func (h *HtmlHandler) subnetQueryHandler(w http.ResponseWriter, r *http.Request)
 		http.Error(w, "There was an error reading the request: "+err.Error(), http.StatusBadRequest)
 		return
 	}
-	data, err := h.DbHook.GetByNetwork(req.IpAddress)
+	data, err := h.DbHook.GetByNetwork(req.NetworkAddress)
 	if err != nil {
 		http.Error(w, "There was an error reading the request: "+err.Error(), http.StatusBadRequest)
 		return
@@ -222,7 +232,15 @@ func (h *HtmlHandler) fqdnQueryHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	dnsList := strings.Split(req.FqdnPattern, ",")
-	data, err := h.DbHook.FilterDnsPattern(dnsList)
+	var ntwrk string
+	if req.NetworkAddress == "" {
+		ntwrk = "%"
+	} else {
+		ntwrk = req.NetworkAddress
+	}
+	h.Log("Query Arguments: " + ntwrk + " " + req.FqdnPattern)
+
+	data, err := h.DbHook.FilterDnsPattern(ntwrk, dnsList)
 	if err != nil {
 		http.Error(w, "There was an error reading the request: "+err.Error(), http.StatusBadRequest)
 		return
@@ -240,9 +258,10 @@ Handler function for HTML serving
 	:param r: pointer to the http.Request coming in
 */
 func (h *HtmlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	h.Log("Recieved " + r.Method + " on path: " + r.RequestURI)
 	switch r.RequestURI {
 	case "/home":
-		h.handleHome(w, r)
+		h.handleHome(w)
 	case "/subnets":
 		h.subnetQueryHandler(w, r)
 	case "/excludefqdn":