From d385bbf906749577f48bb7833037a06d80a823c2 Mon Sep 17 00:00:00 2001 From: Guilherme Rugai Freire Date: Sun, 21 Jul 2024 20:15:25 -0300 Subject: [PATCH] only allow mails to addresses that contain the mail server domain there were some mails coming in that was not for any temp mail addr, it was like spam@otherdomain.com, but still reaching the database. This should prevent this. --- pkg/mail_server/main.go | 88 +++++++++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 33 deletions(-) diff --git a/pkg/mail_server/main.go b/pkg/mail_server/main.go index bbfadb8..ef2e0a9 100644 --- a/pkg/mail_server/main.go +++ b/pkg/mail_server/main.go @@ -8,6 +8,7 @@ import ( "log" "os" "strconv" + "strings" "time" "github.com/GRFreire/nthmail/pkg/mail_utils" @@ -16,7 +17,8 @@ import ( ) type Backend struct { - db *sql.DB + db *sql.DB + domain string } func (backend *Backend) NewSession(c *smtp.Conn) (smtp.Session, error) { @@ -26,7 +28,8 @@ func (backend *Backend) NewSession(c *smtp.Conn) (smtp.Session, error) { } return &Session{ - tx: tx, + tx: tx, + domain: backend.domain, }, nil } @@ -34,6 +37,16 @@ type Session struct { tx *sql.Tx from, rcpt string arrived_at int64 + domain string +} + +func get_addr_domain(addr string) string { + index := strings.Index(addr, "@") + if index < 0 { + return "" + } + + return addr[index+1:] } func (session *Session) AuthPlain(username, password string) error { @@ -48,6 +61,9 @@ func (session *Session) Mail(from string, opts *smtp.MailOptions) error { } func (session *Session) Rcpt(to string, opts *smtp.RcptOptions) error { + if get_addr_domain(to) != session.domain { + return errors.New("To addr domain is not available in this server") + } session.rcpt = to return nil @@ -55,33 +71,38 @@ func (session *Session) Rcpt(to string, opts *smtp.RcptOptions) error { func (session *Session) Data(reader io.Reader) error { defer session.tx.Rollback() - if bytes, err := io.ReadAll(reader); err != nil { + + bytes, err := io.ReadAll(reader) + if err != nil { return err - } else { - - stmt, err := session.tx.Prepare("INSERT INTO mails (arrived_at, rcpt_addr, from_addr, subject, data) VALUES (?, ?, ?, ?, ?)") - if err != nil { - println(err) - return err - } - defer stmt.Close() - - mail_obj, err := mail_utils.Parse_mail(bytes, true) - if err != nil { - return err - } - - _, err = stmt.Exec(session.arrived_at, session.rcpt, mail_obj.From, mail_obj.Subject, bytes) - if err != nil { - return err - } - - err = session.tx.Commit() - if err != nil { - return err - } - } + + mail_obj, err := mail_utils.Parse_mail(bytes, true) + if err != nil { + return err + } + + if get_addr_domain(mail_obj.To) != session.domain { + return errors.New("To addr domain is not available in this server") + } + + stmt, err := session.tx.Prepare("INSERT INTO mails (arrived_at, rcpt_addr, from_addr, subject, data) VALUES (?, ?, ?, ?, ?)") + if err != nil { + println(err) + return err + } + defer stmt.Close() + + _, err = stmt.Exec(session.arrived_at, session.rcpt, mail_obj.From, mail_obj.Subject, bytes) + if err != nil { + return err + } + + err = session.tx.Commit() + if err != nil { + return err + } + return nil } @@ -92,12 +113,6 @@ func (session *Session) Logout() error { } func Start(db *sql.DB) error { - backend := &Backend{ - db: db, - } - - server := smtp.NewServer(backend) - domain, exists := os.LookupEnv("MAIL_SERVER_DOMAIN") if !exists { domain = "localhost" @@ -115,6 +130,13 @@ func Start(db *sql.DB) error { port = 1025 } + backend := &Backend{ + db: db, + domain: domain, + } + + server := smtp.NewServer(backend) + server.Addr = fmt.Sprintf(":%d", port) server.Domain = domain server.WriteTimeout = 60 * time.Second