only allow mails to addresses that contain the mail server domain

there were some mails coming in that was not for any temp mail addr,
it was like spam@otherdomain.com, but still reaching the database. This
should prevent this.
This commit is contained in:
Guilherme Rugai Freire 2024-07-21 20:15:25 -03:00
parent 7fe3ed30aa
commit d385bbf906
No known key found for this signature in database
GPG Key ID: AC1D9B6E48E16AC1

View File

@ -8,6 +8,7 @@ import (
"log" "log"
"os" "os"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/GRFreire/nthmail/pkg/mail_utils" "github.com/GRFreire/nthmail/pkg/mail_utils"
@ -17,6 +18,7 @@ import (
type Backend struct { type Backend struct {
db *sql.DB db *sql.DB
domain string
} }
func (backend *Backend) NewSession(c *smtp.Conn) (smtp.Session, error) { func (backend *Backend) NewSession(c *smtp.Conn) (smtp.Session, error) {
@ -27,6 +29,7 @@ func (backend *Backend) NewSession(c *smtp.Conn) (smtp.Session, error) {
return &Session{ return &Session{
tx: tx, tx: tx,
domain: backend.domain,
}, nil }, nil
} }
@ -34,6 +37,16 @@ type Session struct {
tx *sql.Tx tx *sql.Tx
from, rcpt string from, rcpt string
arrived_at int64 arrived_at int64
domain string
}
func get_addr_domain(addr string) string {
index := strings.Index(addr, "@")
if index < 0 {
return ""
}
return addr[index+1:]
} }
func (session *Session) AuthPlain(username, password string) error { func (session *Session) AuthPlain(username, password string) error {
@ -48,6 +61,9 @@ func (session *Session) Mail(from string, opts *smtp.MailOptions) error {
} }
func (session *Session) Rcpt(to string, opts *smtp.RcptOptions) error { func (session *Session) Rcpt(to string, opts *smtp.RcptOptions) error {
if get_addr_domain(to) != session.domain {
return errors.New("To addr domain is not available in this server")
}
session.rcpt = to session.rcpt = to
return nil return nil
@ -55,9 +71,20 @@ func (session *Session) Rcpt(to string, opts *smtp.RcptOptions) error {
func (session *Session) Data(reader io.Reader) error { func (session *Session) Data(reader io.Reader) error {
defer session.tx.Rollback() defer session.tx.Rollback()
if bytes, err := io.ReadAll(reader); err != nil {
bytes, err := io.ReadAll(reader)
if err != nil {
return err return err
} else { }
mail_obj, err := mail_utils.Parse_mail(bytes, true)
if err != nil {
return err
}
if get_addr_domain(mail_obj.To) != session.domain {
return errors.New("To addr domain is not available in this server")
}
stmt, err := session.tx.Prepare("INSERT INTO mails (arrived_at, rcpt_addr, from_addr, subject, data) VALUES (?, ?, ?, ?, ?)") stmt, err := session.tx.Prepare("INSERT INTO mails (arrived_at, rcpt_addr, from_addr, subject, data) VALUES (?, ?, ?, ?, ?)")
if err != nil { if err != nil {
@ -66,11 +93,6 @@ func (session *Session) Data(reader io.Reader) error {
} }
defer stmt.Close() defer stmt.Close()
mail_obj, err := mail_utils.Parse_mail(bytes, true)
if err != nil {
return err
}
_, err = stmt.Exec(session.arrived_at, session.rcpt, mail_obj.From, mail_obj.Subject, bytes) _, err = stmt.Exec(session.arrived_at, session.rcpt, mail_obj.From, mail_obj.Subject, bytes)
if err != nil { if err != nil {
return err return err
@ -81,7 +103,6 @@ func (session *Session) Data(reader io.Reader) error {
return err return err
} }
}
return nil return nil
} }
@ -92,12 +113,6 @@ func (session *Session) Logout() error {
} }
func Start(db *sql.DB) error { func Start(db *sql.DB) error {
backend := &Backend{
db: db,
}
server := smtp.NewServer(backend)
domain, exists := os.LookupEnv("MAIL_SERVER_DOMAIN") domain, exists := os.LookupEnv("MAIL_SERVER_DOMAIN")
if !exists { if !exists {
domain = "localhost" domain = "localhost"
@ -115,6 +130,13 @@ func Start(db *sql.DB) error {
port = 1025 port = 1025
} }
backend := &Backend{
db: db,
domain: domain,
}
server := smtp.NewServer(backend)
server.Addr = fmt.Sprintf(":%d", port) server.Addr = fmt.Sprintf(":%d", port)
server.Domain = domain server.Domain = domain
server.WriteTimeout = 60 * time.Second server.WriteTimeout = 60 * time.Second