mirror of
https://github.com/GRFreire/nthmail.git
synced 2026-01-09 04:49:39 +00:00
only allow mails to addresses that contain the mail server domain
there were some mails coming in that was not for any temp mail addr, it was like spam@otherdomain.com, but still reaching the database. This should prevent this.
This commit is contained in:
parent
7fe3ed30aa
commit
d385bbf906
@ -8,6 +8,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/GRFreire/nthmail/pkg/mail_utils"
|
||||
@ -16,7 +17,8 @@ import (
|
||||
)
|
||||
|
||||
type Backend struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
domain string
|
||||
}
|
||||
|
||||
func (backend *Backend) NewSession(c *smtp.Conn) (smtp.Session, error) {
|
||||
@ -26,7 +28,8 @@ func (backend *Backend) NewSession(c *smtp.Conn) (smtp.Session, error) {
|
||||
}
|
||||
|
||||
return &Session{
|
||||
tx: tx,
|
||||
tx: tx,
|
||||
domain: backend.domain,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -34,6 +37,16 @@ type Session struct {
|
||||
tx *sql.Tx
|
||||
from, rcpt string
|
||||
arrived_at int64
|
||||
domain string
|
||||
}
|
||||
|
||||
func get_addr_domain(addr string) string {
|
||||
index := strings.Index(addr, "@")
|
||||
if index < 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return addr[index+1:]
|
||||
}
|
||||
|
||||
func (session *Session) AuthPlain(username, password string) error {
|
||||
@ -48,6 +61,9 @@ func (session *Session) Mail(from string, opts *smtp.MailOptions) error {
|
||||
}
|
||||
|
||||
func (session *Session) Rcpt(to string, opts *smtp.RcptOptions) error {
|
||||
if get_addr_domain(to) != session.domain {
|
||||
return errors.New("To addr domain is not available in this server")
|
||||
}
|
||||
session.rcpt = to
|
||||
|
||||
return nil
|
||||
@ -55,33 +71,38 @@ func (session *Session) Rcpt(to string, opts *smtp.RcptOptions) error {
|
||||
|
||||
func (session *Session) Data(reader io.Reader) error {
|
||||
defer session.tx.Rollback()
|
||||
if bytes, err := io.ReadAll(reader); err != nil {
|
||||
|
||||
bytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
|
||||
stmt, err := session.tx.Prepare("INSERT INTO mails (arrived_at, rcpt_addr, from_addr, subject, data) VALUES (?, ?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
println(err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
mail_obj, err := mail_utils.Parse_mail(bytes, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(session.arrived_at, session.rcpt, mail_obj.From, mail_obj.Subject, bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = session.tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
mail_obj, err := mail_utils.Parse_mail(bytes, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if get_addr_domain(mail_obj.To) != session.domain {
|
||||
return errors.New("To addr domain is not available in this server")
|
||||
}
|
||||
|
||||
stmt, err := session.tx.Prepare("INSERT INTO mails (arrived_at, rcpt_addr, from_addr, subject, data) VALUES (?, ?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
println(err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(session.arrived_at, session.rcpt, mail_obj.From, mail_obj.Subject, bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = session.tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -92,12 +113,6 @@ func (session *Session) Logout() error {
|
||||
}
|
||||
|
||||
func Start(db *sql.DB) error {
|
||||
backend := &Backend{
|
||||
db: db,
|
||||
}
|
||||
|
||||
server := smtp.NewServer(backend)
|
||||
|
||||
domain, exists := os.LookupEnv("MAIL_SERVER_DOMAIN")
|
||||
if !exists {
|
||||
domain = "localhost"
|
||||
@ -115,6 +130,13 @@ func Start(db *sql.DB) error {
|
||||
port = 1025
|
||||
}
|
||||
|
||||
backend := &Backend{
|
||||
db: db,
|
||||
domain: domain,
|
||||
}
|
||||
|
||||
server := smtp.NewServer(backend)
|
||||
|
||||
server.Addr = fmt.Sprintf(":%d", port)
|
||||
server.Domain = domain
|
||||
server.WriteTimeout = 60 * time.Second
|
||||
|
||||
Loading…
Reference in New Issue
Block a user