//go:build (linux && !android) || freebsd

package dns

import (
	"bytes"
	"fmt"
	"net/netip"
	"os"
	"strings"

	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/client/internal/statemanager"
)

const (
	// fileDefaultResolvConfBackupLocation is where the pre-NetBird resolv.conf is copied
	// before NetBird overwrites it.
	fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"

	// fileGeneratedResolvConfContentHeader is the marker line written at the top of every
	// NetBird-generated resolv.conf.
	fileGeneratedResolvConfContentHeader = "# Generated by NetBird"

	// fileGeneratedResolvConfContentHeaderNextLine is the complete header block: the marker
	// line, a hint pointing at the backup location, and a trailing blank line.
	fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + "\n" +
		"# The original file can be restored from " + fileDefaultResolvConfBackupLocation + "\n\n"

	// fileMaxLineCharsLimit caps the length of the generated "search" line
	// (presumably mirroring the historic resolv.conf(5) limit).
	fileMaxLineCharsLimit = 256

	// fileMaxNumberOfSearchDomains caps how many domains are put on the "search" line.
	fileMaxNumberOfSearchDomains = 6
)

// fileConfigurator manages host DNS by rewriting resolv.conf (defaultResolvConfPath) directly,
// keeping a backup of the original at fileDefaultResolvConfBackupLocation.
type fileConfigurator struct {
	// repair is built in newFileConfigurator with updateConfig as its callback; applyDNSConfig
	// stops and restarts its file watch around each rewrite, and restoreHostDNS stops it.
	repair              *repair
	// originalPerms is the mode of resolv.conf captured by backup(). It stays zero when the
	// backup file already existed and backup() was not called in this process.
	originalPerms       os.FileMode
	// nbNameserverIP is the NetBird resolver address (config.ServerIP) written as the nameserver.
	nbNameserverIP      netip.Addr
	// originalNameservers are the nameservers parsed from the backed-up original resolv.conf.
	originalNameservers []netip.Addr
}

// newFileConfigurator returns a configurator that edits resolv.conf directly. Its repair helper
// is bound to defaultResolvConfPath and receives updateConfig as the callback used to rewrite the
// file. The returned error is always nil.
func newFileConfigurator() (*fileConfigurator, error) {
	configurator := new(fileConfigurator)
	configurator.repair = newRepair(defaultResolvConfPath, configurator.updateConfig)
	return configurator, nil
}

// supportCustomPort reports whether DNS traffic can be sent to a non-default port. The generated
// resolv.conf lists nameservers by address only, so this is always false.
func (f *fileConfigurator) supportCustomPort() bool {
	const customPortSupported = false
	return customPortSupported
}

// applyDNSConfig points resolv.conf at the NetBird nameserver (config.ServerIP).
//
// The current resolv.conf is backed up only if no backup exists yet. The backup is then parsed so
// its search domains, nameservers and other lines can be carried into the generated file. The
// repair watcher is stopped before the rewrite and restarted only if the rewrite succeeds.
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
	if backupMissing := !f.isBackupFileExist(); backupMissing {
		if backupErr := f.backup(); backupErr != nil {
			return fmt.Errorf("backup resolv.conf: %w", backupErr)
		}
	}

	domains := searchDomains(config)
	f.nbNameserverIP = config.ServerIP

	original, parseErr := parseBackupResolvConf()
	if parseErr != nil {
		// Deliberately non-fatal: carry on with whatever the parser returned.
		// NOTE(review): this relies on parseBackupResolvConf returning a non-nil value even on
		// error; a nil result would panic on the field access below — confirm.
		log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, parseErr)
	}
	f.originalNameservers = original.nameServers

	f.repair.stopWatchFileChanges()

	if updateErr := f.updateConfig(domains, f.nbNameserverIP, original, stateManager); updateErr != nil {
		return updateErr
	}
	f.repair.watchFileChanges(domains, f.nbNameserverIP, stateManager)
	return nil
}

// getOriginalNameservers returns the nameservers parsed from the backed-up original resolv.conf
// by the most recent applyDNSConfig call. The stored slice is returned as-is, not copied.
func (f *fileConfigurator) getOriginalNameservers() []netip.Addr {
	servers := f.originalNameservers
	return servers
}

// updateConfig writes a NetBird-managed resolv.conf that contains:
//   - nbNameserverIP as its only nameserver,
//   - nbSearchDomains merged ahead of cfg's search domains,
//   - cfg's remaining lines (cfg.others) verbatim.
//
// If the write fails, the original file is restored from the backup. A restore failure is logged,
// and the write error is returned. After a successful write, an unclean-shutdown indicator is
// recorded. Failing to record it is only logged.
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP netip.Addr, cfg *resolvConf, stateManager *statemanager.Manager) error {
	searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)

	buf := prepareResolvConfContent(
		searchDomainList,
		[]string{nbNameserverIP.String()},
		cfg.others,
	)

	// originalPerms is only set by backup(). When a backup already existed before this process
	// started it is zero, and os.WriteFile would create a missing resolv.conf with mode 0000,
	// making it unreadable for non-root resolvers. The mode only matters when the file is created.
	perms := f.originalPerms
	if perms == 0 {
		perms = 0o644
	}

	log.Debugf("creating managed file %s", defaultResolvConfPath)
	if err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), perms); err != nil {
		if restoreErr := f.restore(); restoreErr != nil {
			// Log the restore failure itself; the write error is returned below.
			log.Errorf("attempt to restore default file failed with error: %s", restoreErr)
		}
		return fmt.Errorf("creating resolver file %s. Error: %w", defaultResolvConfPath, err)
	}

	log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)

	// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
	if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil {
		log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
	}

	return nil
}

// restoreHostDNS stops the repair watcher first, so it does not react to the restore, and then
// puts the backed-up original resolv.conf back in place.
func (f *fileConfigurator) restoreHostDNS() error {
	watcher := f.repair
	watcher.stopWatchFileChanges()

	err := f.restore()
	return err
}

// string returns the short identifier of this configurator implementation.
func (f *fileConfigurator) string() string {
	const name = "file"
	return name
}

// backup records the current mode of resolv.conf in f.originalPerms and copies the file to
// fileDefaultResolvConfBackupLocation.
func (f *fileConfigurator) backup() error {
	info, statErr := os.Stat(defaultResolvConfPath)
	if statErr != nil {
		return fmt.Errorf("checking stats for %s file. Error: %w", defaultResolvConfPath, statErr)
	}
	f.originalPerms = info.Mode()

	if copyErr := copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation); copyErr != nil {
		return fmt.Errorf("backing up %s: %w", defaultResolvConfPath, copyErr)
	}
	return nil
}

// restore copies the backup over resolv.conf. Only if that copy succeeds is the backup removed;
// the removal error, if any, is returned.
func (f *fileConfigurator) restore() error {
	copyErr := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
	if copyErr != nil {
		return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, copyErr)
	}
	return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}

// restoreUncleanShutdownDNS decides whether to restore resolv.conf after an unclean shutdown.
//
// The file is restored from the unclean-shutdown copy in two cases: the current resolv.conf has
// no nameserver, or its first nameserver still equals storedDNSAddress (NetBird's now-unavailable
// resolver address). Otherwise the file is left untouched and the decision is logged.
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error {
	current, err := parseDefaultResolvConf()
	if err != nil {
		return fmt.Errorf("parse current resolv.conf: %w", err)
	}

	servers := current.nameServers
	switch {
	case len(servers) == 0:
		// nothing configured at all
		return restoreResolvConfFile()
	case servers[0] == storedDNSAddress:
		// still pointing at the NetBird resolver
		return restoreResolvConfFile()
	}

	log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", servers[0], storedDNSAddress)
	return nil
}

// isBackupFileExist reports whether the backup of the original resolv.conf can be stat'ed.
// Any stat error, not only "does not exist", is treated as "no backup".
func (f *fileConfigurator) isBackupFileExist() bool {
	if _, err := os.Stat(fileDefaultResolvConfBackupLocation); err != nil {
		return false
	}
	return true
}

func restoreResolvConfFile() error {
	log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)

	if err := copyFile(fileUncleanShutdownResolvConfLocation, defaultResolvConfPath); err != nil {
		return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
	}

	return nil
}

// prepareResolvConfContent renders resolv.conf content in this order:
//   - the NetBird header block,
//   - every entry of others verbatim, one per line,
//   - a single "search" line (only when searchDomains is non-empty),
//   - one "nameserver" line per entry of nameServers.
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
	var out bytes.Buffer

	// writeLine concatenates parts and terminates them with a newline.
	writeLine := func(parts ...string) {
		for _, part := range parts {
			out.WriteString(part)
		}
		out.WriteByte('\n')
	}

	out.WriteString(fileGeneratedResolvConfContentHeaderNextLine)

	for _, line := range others {
		writeLine(line)
	}

	if len(searchDomains) != 0 {
		writeLine("search ", strings.Join(searchDomains, " "))
	}

	for _, server := range nameServers {
		writeLine("nameserver ", server)
	}

	return out
}

// searchDomains returns the domains from config that belong on the resolv.conf search list.
// Entries that are match-only or disabled are skipped. A single trailing dot is stripped from
// the rest. Entries that end up empty after trimming (such as the root zone ".") are also skipped,
// because an empty search domain is meaningless but would still consume one of the limited
// search-list slots. The result is never nil.
func searchDomains(config HostDNSConfig) []string {
	listOfDomains := make([]string, 0)
	for _, dConf := range config.Domains {
		if dConf.MatchOnly || dConf.Disabled {
			continue
		}

		domain := strings.TrimSuffix(dConf.Domain, ".")
		if domain == "" {
			continue
		}
		listOfDomains = append(listOfDomains, domain)
	}
	return listOfDomains
}

// mergeSearchDomains combines searchDomains (first) with originalSearchDomains into one list.
// Duplicates and entries that would exceed the search-list limits are dropped; the limits are
// enforced by validateAndFillSearchDomains.
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
	merged := make([]string, 0, len(searchDomains)+len(originalSearchDomains))

	// The line starts with the "search" keyword; each call continues from the previous length.
	usedChars := validateAndFillSearchDomains(len("search"), &merged, searchDomains)
	// The final line length is not needed.
	validateAndFillSearchDomains(usedChars, &merged, originalSearchDomains)

	return merged
}

// validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long
// extend s slice with vs elements
// return with the number of characters in the searchDomains line
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
	for _, sd := range vs {
		duplicated := false
		for _, fs := range *s {
			if fs == sd {
				duplicated = true
				break
			}

		}

		if duplicated {
			continue
		}

		tmpCharsNumber := initialLineChars + 1 + len(sd)
		if tmpCharsNumber > fileMaxLineCharsLimit {
			// lets log all skipped Domains
			log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
			continue
		}

		initialLineChars = tmpCharsNumber

		if len(*s) >= fileMaxNumberOfSearchDomains {
			// lets log all skipped Domains
			log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
			continue
		}
		*s = append(*s, sd)
	}

	return initialLineChars
}

// copyFile copies the contents of src to dest. When dest has to be created, it gets src's mode;
// an existing dest is truncated and keeps its mode, per os.WriteFile. Errors wrap the
// underlying os error with %w so callers can use errors.Is (e.g. fs.ErrNotExist) through the
// %w wrapping they already add.
func copyFile(src, dest string) error {
	stats, err := os.Stat(src)
	if err != nil {
		return fmt.Errorf("checking stats for %s file when copying it. Error: %w", src, err)
	}

	bytesRead, err := os.ReadFile(src)
	if err != nil {
		return fmt.Errorf("reading the file %s file for copy. Error: %w", src, err)
	}

	if err := os.WriteFile(dest, bytesRead, stats.Mode()); err != nil {
		return fmt.Errorf("writing the destination file %s for copy. Error: %w", dest, err)
	}
	return nil
}

// isContains reports whether every element of subList also appears in list (multiplicity is
// ignored). An empty subList is trivially contained.
func isContains(subList []string, list []string) bool {
nextWanted:
	for _, want := range subList {
		for _, have := range list {
			if have == want {
				continue nextWanted
			}
		}
		return false
	}
	return true
}
