diff --git a/README.md b/README.md index 828e912..ab287ec 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,36 @@ -# Simple DNS Server +# 🌎 Simple DNS server +A tiny DNS server that is capable of serving records configured in a MySQL table, or configured statically in a JSON file -## Installation -1. Create a file with the path `/etc/systemd/system/argondns.service` and with the following content: +## 🧐 Configuration documentation + +- `mode`: Can be either `db` if your records are stored in a MySQL database, or `static_records` if your records are static and stored in the configuration JSON file. + +- `db`: The MySQL server & database credentials. This works only if `mode` is set to `db` + +- `listener`: The listening/bind settings for the DNS server (usually has to be kept binding on port 53 to be able to accept DNS requests). + +- `process_unstored_dns_queries`: Should the DNS server also accept queries of records that are not stored in your database table/static records configuration? Enable this if yes. + +- `static_records`: Configure your static records here, one per JSON array. This works only if `mode` is set to `static_records` + +## 🛠️ Installation as a service + +1. Store your configuration file at `/etc/simpledns/config.json` + You can copy the example configuration file and change it to serve your needs. +2. If running Simple DNS server in the `db` mode, use this database structure for your records table: https://github.com/oddmario/simple-dns-server/blob/main/db_structure.sql +3. Place the binary file of Simple DNS server at `/usr/local/bin` (e.g. `/usr/local/bin/simpledns`) +4. Make the binary file executable: `chmod u+x /usr/local/bin/simpledns` +5. Create a systemd service for the application. This can be done by creating /etc/systemd/system/simpledns.service to have this content: ``` [Unit] -Description=ArgonDNS +Description=SimpleDNSserver [Service] User=root -WorkingDirectory=/root/argondns +WorkingDirectory=/usr/local/bin LimitNOFILE=2097152 TasksMax=infinity -ExecStart=/root/argondns/simpledns_linux_amd64 +ExecStart=/usr/local/bin/simpledns /etc/simpledns/config.json Restart=on-failure StartLimitInterval=180 StartLimitBurst=30 @@ -19,16 +38,11 @@ RestartSec=5s [Install] WantedBy=multi-user.target -``` - -2. Put the `simpledns_linux_amd64` executable file at a directory with path `/root/argondns/` - so the final path of the executable will be `/root/argondns/simpledns_linux_amd64` - -3. Run `chmod -R 777 /root/argondns/simpledns_linux_amd64` - -4. Place the `config.json` file found in this repository at `/root/argondns` along with the `simpledns_linux_amd64` executable file - -5. Run `systemctl enable argondns` +``` 6. Port 53 (the DNS server port) is usually in use by default. To solve this, follow https://unix.stackexchange.com/a/676977/405697 then run `systemctl restart systemd-resolved` - -7. Make sure that there are no other DNS servers (such as bind9) are running, then run `systemctl start argondns` to start our DNS server :) \ No newline at end of file +7. Make sure that there are no other DNS servers (such as bind9) are running +8. Enable the Simple DNS server service on startup & start it now: +``` +systemctl enable --now simpledns.service +``` \ No newline at end of file diff --git a/config.json b/config.json index 50c7e51..84a9c3c 100644 --- a/config.json +++ b/config.json @@ -1,4 +1,5 @@ { + "mode": "db", "db": { "host": "localhost", "username": "root", @@ -14,5 +15,17 @@ "process_unstored_dns_queries": { "is_enabled": false, "dns_server": "8.8.8.8:53" - } + }, + "static_records": [ + { + "type": "A", + "name": "", + "value": "", + "ttl": 0, + "srv_priority": 0, + "srv_weight": 0, + "srv_port": 0, + "srv_target": "" + } + ] } \ No newline at end of file diff --git a/constants/constants.go b/constants/constants.go new file mode 100644 index 0000000..ae2f197 --- /dev/null +++ b/constants/constants.go @@ -0,0 +1,4 @@ +package constants + +var ConfigFilePath string = "" +var Version string = "v1.2" diff --git a/db/utils.go b/db/utils.go index e5e6de7..5fbf17b 100644 --- a/db/utils.go +++ b/db/utils.go @@ -2,6 +2,8 @@ package db import ( "database/sql" + "errors" + "time" ) func EasyQuery(query string, args ...any) (*sql.Rows, error) { @@ -23,3 +25,31 @@ func EasyExec(query string, args ...any) (sql.Result, error) { return res, nil } + +func RetriedDbQuery(retries int, query string, args ...any) (*sql.Rows, error) { + for range retries { + res, err := EasyQuery(query, args...) + if err != nil { + time.Sleep(1 * time.Second) + continue + } else { + return res, err + } + } + + return nil, errors.New("failed") +} + +func RetriedDbExec(retries int, query string, args ...any) (sql.Result, error) { + for range retries { + res, err := EasyExec(query, args...) + if err != nil { + time.Sleep(1 * time.Second) + continue + } else { + return res, err + } + } + + return nil, errors.New("failed") +} diff --git a/dnsparser/A.go b/dnsparser/A.go index 7e693a8..a6d8df3 100644 --- a/dnsparser/A.go +++ b/dnsparser/A.go @@ -2,52 +2,32 @@ package dnsparser import ( "mario/simple-dns-server/db" + "mario/simple-dns-server/records" "net" "github.com/miekg/dns" ) func A(m *dns.Msg, name_dot, name_nodot string) bool { - res, err := db.EasyQuery("SELECT id, record_type, record_name, record_value, record_ttl, is_disposable FROM dns_records WHERE record_name = ? AND record_type = 'A'", name_nodot) - if err != nil { - // an error has occured while preparing the SQL statement - return false - } - defer res.Close() - - var recordsFound bool = false + recordsFound, records := records.GetDNSRecord(name_nodot, "A") - for res.Next() { - recordsFound = true - - var record_id int64 - var record_type string - var record_name string - var record_value string - var record_ttl int64 - var record_isdisposable int64 + if !recordsFound { + // DNS record(s) was/were not found - err = res.Scan(&record_id, &record_type, &record_name, &record_value, &record_ttl, &record_isdisposable) - if err != nil { - // an error has occured - return false - } + return false + } + for _, record := range records { r := new(dns.A) - r.Hdr = dns.RR_Header{Name: name_dot, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: uint32(record_ttl)} - r.A = net.ParseIP(record_value) + r.Hdr = dns.RR_Header{Name: name_dot, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: uint32(record.TTL)} + r.A = net.ParseIP(record.Value) m.Answer = append(m.Answer, r) - if record_isdisposable > 0 { - db.EasyExec("DELETE FROM dns_records WHERE id = ?", record_id) + if record.IsDisposable { + db.RetriedDbExec(10, "DELETE FROM dns_records WHERE id = ?", record.ID) } } - if !recordsFound { - // DNS record not found in the database - return false - } - return true } diff --git a/dnsparser/CNAME.go b/dnsparser/CNAME.go index 0adca04..e939b34 100644 --- a/dnsparser/CNAME.go +++ b/dnsparser/CNAME.go @@ -2,51 +2,31 @@ package dnsparser import ( "mario/simple-dns-server/db" + "mario/simple-dns-server/records" "github.com/miekg/dns" ) func CNAME(m *dns.Msg, name_dot, name_nodot string) bool { - res, err := db.EasyQuery("SELECT id, record_type, record_name, record_value, record_ttl, is_disposable FROM dns_records WHERE record_name = ? AND record_type = 'CNAME'", name_nodot) - if err != nil { - // an error has occured while preparing the SQL statement - return false - } - defer res.Close() - - var recordsFound bool = false + recordsFound, records := records.GetDNSRecord(name_nodot, "CNAME") - for res.Next() { - recordsFound = true - - var record_id int64 - var record_type string - var record_name string - var record_value string - var record_ttl int64 - var record_isdisposable int64 + if !recordsFound { + // DNS record(s) was/were not found - err = res.Scan(&record_id, &record_type, &record_name, &record_value, &record_ttl, &record_isdisposable) - if err != nil { - // an error has occured - return false - } + return false + } + for _, record := range records { r := new(dns.CNAME) - r.Hdr = dns.RR_Header{Name: name_dot, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(record_ttl)} - r.Target = dns.Fqdn(record_value) + r.Hdr = dns.RR_Header{Name: name_dot, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(record.TTL)} + r.Target = dns.Fqdn(record.Value) m.Answer = append(m.Answer, r) - if record_isdisposable > 0 { - db.EasyExec("DELETE FROM dns_records WHERE id = ?", record_id) + if record.IsDisposable { + db.RetriedDbExec(10, "DELETE FROM dns_records WHERE id = ?", record.ID) } } - if !recordsFound { - // DNS record not found in the database - return false - } - return true } diff --git a/dnsparser/NS.go b/dnsparser/NS.go index a762908..646b4ac 100644 --- a/dnsparser/NS.go +++ b/dnsparser/NS.go @@ -2,51 +2,31 @@ package dnsparser import ( "mario/simple-dns-server/db" + "mario/simple-dns-server/records" "github.com/miekg/dns" ) func NS(m *dns.Msg, name_dot, name_nodot string) bool { - res, err := db.EasyQuery("SELECT id, record_type, record_name, record_value, record_ttl, is_disposable FROM dns_records WHERE record_name = ? AND record_type = 'NS'", name_nodot) - if err != nil { - // an error has occured while preparing the SQL statement - return false - } - defer res.Close() - - var recordsFound bool = false + recordsFound, records := records.GetDNSRecord(name_nodot, "NS") - for res.Next() { - recordsFound = true - - var record_id int64 - var record_type string - var record_name string - var record_value string - var record_ttl int64 - var record_isdisposable int64 + if !recordsFound { + // DNS record(s) was/were not found - err = res.Scan(&record_id, &record_type, &record_name, &record_value, &record_ttl, &record_isdisposable) - if err != nil { - // an error has occured - return false - } + return false + } + for _, record := range records { r := new(dns.NS) - r.Hdr = dns.RR_Header{Name: name_dot, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(record_ttl)} - r.Ns = dns.Fqdn(record_value) + r.Hdr = dns.RR_Header{Name: name_dot, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(record.TTL)} + r.Ns = dns.Fqdn(record.Value) m.Answer = append(m.Answer, r) - if record_isdisposable > 0 { - db.EasyExec("DELETE FROM dns_records WHERE id = ?", record_id) + if record.IsDisposable { + db.RetriedDbExec(10, "DELETE FROM dns_records WHERE id = ?", record.ID) } } - if !recordsFound { - // DNS record not found in the database - return false - } - return true } diff --git a/dnsparser/SRV.go b/dnsparser/SRV.go index 9e6279a..6dd9ab6 100644 --- a/dnsparser/SRV.go +++ b/dnsparser/SRV.go @@ -2,58 +2,34 @@ package dnsparser import ( "mario/simple-dns-server/db" + "mario/simple-dns-server/records" "github.com/miekg/dns" ) func SRV(m *dns.Msg, name_dot, name_nodot string) bool { - res, err := db.EasyQuery("SELECT id, record_type, record_name, record_value, record_ttl, srv_priority, srv_weight, srv_port, srv_target, is_disposable FROM dns_records WHERE record_name = ? AND record_type = 'SRV'", name_nodot) - if err != nil { - // an error has occured while preparing the SQL statement + recordsFound, records := records.GetDNSRecord(name_nodot, "SRV") + + if !recordsFound { + // DNS record(s) was/were not found + return false } - defer res.Close() - - var recordsFound bool = false - - for res.Next() { - recordsFound = true - - var record_id int64 - var record_type string - var record_name string - var record_value string - var record_ttl int64 - var srv_priority int64 - var srv_weight int64 - var srv_port int64 - var srv_target string - var record_isdisposable int64 - - err = res.Scan(&record_id, &record_type, &record_name, &record_value, &record_ttl, &srv_priority, &srv_weight, &srv_port, &srv_target, &record_isdisposable) - if err != nil { - // an error has occured - return false - } + for _, record := range records { r := new(dns.SRV) - r.Hdr = dns.RR_Header{Name: name_dot, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: uint32(record_ttl)} - r.Priority = uint16(srv_priority) - r.Weight = uint16(srv_weight) - r.Port = uint16(srv_port) - r.Target = dns.Fqdn(srv_target) + r.Hdr = dns.RR_Header{Name: name_dot, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: uint32(record.TTL)} + r.Priority = uint16(record.SRVPriority) + r.Weight = uint16(record.SRVWeight) + r.Port = uint16(record.SRVPort) + r.Target = dns.Fqdn(record.SRVTarget) m.Answer = append(m.Answer, r) - if record_isdisposable > 0 { - db.EasyExec("DELETE FROM dns_records WHERE id = ?", record_id) + if record.IsDisposable { + db.RetriedDbExec(10, "DELETE FROM dns_records WHERE id = ?", record.ID) } } - if !recordsFound { - // DNS record not found in the database - return false - } - return true } diff --git a/go.mod b/go.mod index 0fb3937..afbc09c 100644 --- a/go.mod +++ b/go.mod @@ -1,19 +1,21 @@ module mario/simple-dns-server -go 1.20 +go 1.22.6 -require github.com/miekg/dns v1.1.58 +require github.com/miekg/dns v1.1.62 require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + golang.org/x/sync v0.8.0 // indirect ) require ( - github.com/go-sql-driver/mysql v1.7.1 - github.com/tidwall/gjson v1.17.0 - golang.org/x/mod v0.14.0 // indirect - golang.org/x/net v0.20.0 // indirect - golang.org/x/sys v0.16.0 // indirect - golang.org/x/tools v0.17.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 + github.com/tidwall/gjson v1.17.3 + golang.org/x/mod v0.21.0 // indirect + golang.org/x/net v0.29.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/tools v0.24.0 // indirect ) diff --git a/go.sum b/go.sum index 6d6acef..a39770f 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,23 @@ -github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= -github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4= -github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY= -github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= -github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= +github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= -golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= diff --git a/main.go b/main.go index 6dc3558..6083943 100644 --- a/main.go +++ b/main.go @@ -1,10 +1,15 @@ package main import ( + "errors" + "fmt" "log" "net" + "os" + "path/filepath" "strings" + "mario/simple-dns-server/constants" "mario/simple-dns-server/db" "mario/simple-dns-server/dnsclient" "mario/simple-dns-server/dnsparser" @@ -80,9 +85,26 @@ func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { func main() { // https://gist.github.com/walm/0d67b4fb2d5daf3edd4fad3e13b162cb + args := os.Args[1:] + + if len(args) >= 1 { + constants.ConfigFilePath = args[0] + } else { + constants.ConfigFilePath, _ = filepath.Abs("./config.json") + } + + if _, err := os.Stat(constants.ConfigFilePath); errors.Is(err, os.ErrNotExist) { + log.Fatal("The specified configuration file does not exist.") + } + + fmt.Println("[INFO] Starting Simple DNS server v" + constants.Version + "...") + utils.LoadConfig() - db.InitDb() - defer db.Db.Close() + + if utils.Config.Get("mode").String() == "db" { + db.InitDb() + defer db.Db.Close() + } workers.Init() @@ -92,7 +114,9 @@ func main() { var listenerType string = utils.Config.Get("listener.type").String() server := &dns.Server{Addr: listenerData, Net: listenerType} + log.Printf("Starting at %s\n", listenerData) + err := server.ListenAndServe() defer server.Shutdown() if err != nil { diff --git a/records/models.go b/records/models.go new file mode 100644 index 0000000..3ade9a1 --- /dev/null +++ b/records/models.go @@ -0,0 +1,14 @@ +package records + +type Record struct { + ID int64 + Type string + Name string + Value string + TTL int64 + SRVPriority int64 + SRVWeight int64 + SRVPort int64 + SRVTarget string + IsDisposable bool +} diff --git a/records/records.go b/records/records.go new file mode 100644 index 0000000..bc14e24 --- /dev/null +++ b/records/records.go @@ -0,0 +1,98 @@ +package records + +import ( + "mario/simple-dns-server/db" + "mario/simple-dns-server/utils" + + "github.com/tidwall/gjson" +) + +func GetDNSRecord(record_name_, record_type_ string) (bool, []*Record) { + records := []*Record{} + recordsFound := false + + if utils.Config.Get("mode").String() == "db" { + res, err := db.RetriedDbQuery(10, "SELECT id, record_type, record_name, record_value, record_ttl, srv_priority, srv_weight, srv_port, srv_target, is_disposable FROM dns_records WHERE record_name = ? AND record_type = ?", record_name_, record_type_) + if err != nil { + // an error has occured while preparing the SQL statement + + return recordsFound, records + } + + defer res.Close() + + for res.Next() { + recordsFound = true + + var record_id int64 + var record_type string + var record_name string + var record_value string + var record_ttl int64 + var srv_priority int64 + var srv_weight int64 + var srv_port int64 + var srv_target string + var record_isdisposable int64 + + err = res.Scan(&record_id, &record_type, &record_name, &record_value, &record_ttl, &srv_priority, &srv_weight, &srv_port, &srv_target, &record_isdisposable) + if err != nil { + // an error has occured + + recordsFound = false + + return recordsFound, records + } + + isDisposable := false + if record_isdisposable >= 1 { + isDisposable = true + } + + records = append(records, &Record{ + ID: record_id, + Type: record_type, + Name: record_name, + Value: record_value, + TTL: record_ttl, + SRVPriority: srv_priority, + SRVWeight: srv_weight, + SRVPort: srv_port, + SRVTarget: srv_target, + IsDisposable: isDisposable, + }) + } + } else { + utils.Config.Get("static_records").ForEach(func(key, value gjson.Result) bool { + var record_type string = value.Get("type").String() + var record_name string = value.Get("name").String() + var record_value string = value.Get("value").String() + var record_ttl int64 = value.Get("ttl").Int() + var srv_priority int64 = value.Get("srv_priority").Int() + var srv_weight int64 = value.Get("srv_weight").Int() + var srv_port int64 = value.Get("srv_port").Int() + var srv_target string = value.Get("srv_target").String() + + if record_name == record_name_ && record_type == record_type_ { + recordsFound = true + + records = append(records, &Record{ + ID: -1, + Type: record_type, + Name: record_name, + Value: record_value, + TTL: record_ttl, + SRVPriority: srv_priority, + SRVWeight: srv_weight, + SRVPort: srv_port, + SRVTarget: srv_target, + IsDisposable: false, + }) + } + + return true + }) + } + + return recordsFound, records +} diff --git a/utils/config.go b/utils/config.go index cbf33d2..f1cc970 100644 --- a/utils/config.go +++ b/utils/config.go @@ -1,8 +1,9 @@ package utils import ( + "log" + "mario/simple-dns-server/constants" "os" - "path/filepath" "github.com/tidwall/gjson" ) @@ -13,9 +14,14 @@ var IsProcessUnstoredQueriesEnabled bool = false var Server_ProcessUnstoredQueries string = "" func LoadConfig() bool { - path, _ := filepath.Abs("./config.json") - cfg_content, _ := os.ReadFile(path) - Config = gjson.Parse(string(cfg_content)) + cfg_content, _ := os.ReadFile(constants.ConfigFilePath) + cfgContentString := BytesToString(cfg_content) + + if !gjson.Valid(cfgContentString) { + log.Fatal("[ERROR] Malformed configuration file") + } else { + Config = gjson.Parse(cfgContentString) + } IsProcessUnstoredQueriesEnabled = Config.Get("process_unstored_dns_queries.is_enabled").Bool() Server_ProcessUnstoredQueries = Config.Get("process_unstored_dns_queries.dns_server").String() diff --git a/utils/conversions.go b/utils/conversions.go index 6eb660c..86f264e 100644 --- a/utils/conversions.go +++ b/utils/conversions.go @@ -1,6 +1,9 @@ package utils -import "strconv" +import ( + "strconv" + "unsafe" +) func StrToI64(str string) int64 { if len(str) <= 0 { @@ -31,3 +34,15 @@ func StrToI(str string) int { } return res } + +// https://josestg.medium.com/140x-faster-string-to-byte-and-byte-to-string-conversions-with-zero-allocation-in-go-200b4d7105fc +func BytesToString(b []byte) string { + // Ignore if your IDE shows an error here; it's a false positive. + p := unsafe.SliceData(b) + return unsafe.String(p, len(b)) +} +func StringToBytes(s string) []byte { + p := unsafe.StringData(s) + b := unsafe.Slice(p, len(s)) + return b +} diff --git a/workers/check_expired_records.go b/workers/check_expired_records.go index daba5c8..a409aa0 100644 --- a/workers/check_expired_records.go +++ b/workers/check_expired_records.go @@ -7,6 +7,10 @@ import ( ) func checkExpiredRecords() { + if utils.Config.Get("mode").String() != "db" { + return + } + t := time.NewTicker(time.Second * 10) defer t.Stop()