htmxmqtt/handler.go

158 lines
3.9 KiB
Go

2023-12-30 16:30:23 -07:00
package htmxmqtt
2023-12-30 14:33:52 -07:00
import (
"time"
"net/http"
2023-12-30 16:39:30 -07:00
"log/slog"
2023-12-30 14:33:52 -07:00
"fmt"
"context"
"sync"
"slices"
"nhooyr.io/websocket"
mqtt "github.com/eclipse/paho.mqtt.golang"
)
// MQTTFormatFunc converts an MQTT message into the bytes that are written to
// a websocket client as a text frame.
type MQTTFormatFunc func(mqtt.Message) []byte
2023-12-30 14:33:52 -07:00
type MQTTHandler struct {
sync.Mutex
2023-12-30 16:30:23 -07:00
format MQTTFormatFunc
channels []chan mqtt.Message
2023-12-30 14:33:52 -07:00
}
2023-12-30 16:30:23 -07:00
func (handler *MQTTHandler) processMessage(client mqtt.Client, message mqtt.Message) {
2023-12-30 14:33:52 -07:00
message.Ack()
handler.Lock()
defer handler.Unlock()
2023-12-30 16:30:23 -07:00
remaining := make([]chan mqtt.Message, 0, len(handler.channels))
for _, channel := range(handler.channels) {
2023-12-30 14:33:52 -07:00
select {
case channel <- message:
remaining = append(remaining, channel)
default:
2023-12-30 16:39:30 -07:00
slog.Warn("Channel overflow")
2023-12-30 14:33:52 -07:00
}
}
2023-12-30 16:30:23 -07:00
handler.channels = remaining
2023-12-30 14:33:52 -07:00
}
2023-12-30 16:30:23 -07:00
func (handler *MQTTHandler) addChannel(channel chan mqtt.Message) func() {
2023-12-30 14:33:52 -07:00
handler.Lock()
defer handler.Unlock()
2023-12-30 16:30:23 -07:00
handler.channels = append(handler.channels, channel)
2023-12-30 14:33:52 -07:00
return func() {
handler.Lock()
defer handler.Unlock()
2023-12-30 16:30:23 -07:00
idx := slices.Index(handler.channels, channel)
2023-12-30 14:33:52 -07:00
if idx < 0 {
return
}
2023-12-30 16:30:23 -07:00
handler.channels[idx] = handler.channels[len(handler.channels)-1]
handler.channels = handler.channels[:len(handler.channels)-1]
2023-12-30 14:33:52 -07:00
}
}
type MQTTHandlerClient struct {
mqtt.Client
2023-12-30 16:30:23 -07:00
subscriptions map[*MQTTHandler]string
subscribeTimeout time.Duration
2023-12-30 14:33:52 -07:00
}
func NewMQTTHandlerClient(broker string, username string, password string, id string) (*MQTTHandlerClient, error) {
opts := mqtt.NewClientOptions()
opts.AddBroker(broker)
opts.SetClientID(id)
opts.SetUsername(username)
opts.SetPassword(password)
opts.SetKeepAlive(2 * time.Second)
opts.SetPingTimeout(1 * time.Second)
client := mqtt.NewClient(opts)
if token := client.Connect(); token.Wait() && token.Error() != nil {
return nil, token.Error()
}
return &MQTTHandlerClient{
Client: client,
2023-12-30 16:30:23 -07:00
subscriptions: map[*MQTTHandler]string{},
subscribeTimeout: 1*time.Second,
2023-12-30 14:33:52 -07:00
}, nil
}
func (client *MQTTHandlerClient) NewHandler(subscription string, format MQTTFormatFunc) (*MQTTHandler, error) {
2023-12-30 14:33:52 -07:00
handler := &MQTTHandler{
2023-12-30 16:30:23 -07:00
format: format,
2023-12-30 14:33:52 -07:00
}
2023-12-30 16:30:23 -07:00
sub_token := client.Subscribe(subscription, 0x00, handler.processMessage)
2023-12-30 14:33:52 -07:00
2023-12-30 16:30:23 -07:00
timeout := sub_token.WaitTimeout(client.subscribeTimeout)
2023-12-30 14:33:52 -07:00
if timeout == false || sub_token.Error() != nil {
return nil, fmt.Errorf("Failed to subscribe to %s - %e", subscription, sub_token.Error())
}
2023-12-30 16:30:23 -07:00
client.subscriptions[handler] = subscription
2023-12-30 14:33:52 -07:00
return handler, nil
}
func (handler *MQTTHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
channel := make(chan mqtt.Message, 1)
2023-12-30 16:30:23 -07:00
remove_channel := handler.addChannel(channel)
2023-12-30 14:33:52 -07:00
defer remove_channel()
conn, err := websocket.Accept(w, r, nil)
if err != nil {
2023-12-30 16:39:30 -07:00
slog.Error("websocket accept error", "error", err)
return
} else {
slog.Info("new websocket connection", "addr", r.RemoteAddr)
2023-12-30 14:33:52 -07:00
}
defer conn.CloseNow()
ctx, cancel_func := context.WithCancel(context.Background())
go func(conn *websocket.Conn, cancel_func context.CancelFunc) {
for true {
msg_type, data, err := conn.Read(ctx)
if err != nil {
2023-12-30 16:39:30 -07:00
slog.Error("websocket error", "error", err)
2023-12-30 14:33:52 -07:00
cancel_func()
break
} else {
2023-12-30 16:39:30 -07:00
slog.Debug("websocket data", "type", msg_type, "data", string(data))
2023-12-30 14:33:52 -07:00
}
}
}(conn, cancel_func)
running := true
done := ctx.Done()
for running == true {
select {
case <- done:
2023-12-30 16:39:30 -07:00
slog.Debug("websocket context done")
2023-12-30 14:33:52 -07:00
running = false
case message := <- channel:
2023-12-30 16:30:23 -07:00
text := handler.format(message)
2023-12-30 16:39:30 -07:00
slog.Debug("websocket write", "data", text)
err := conn.Write(ctx, websocket.MessageText, text)
2023-12-30 14:33:52 -07:00
if err != nil {
2023-12-30 16:39:30 -07:00
slog.Error("websocket write error", "error", err)
2023-12-30 14:33:52 -07:00
running = false
}
}
}
2023-12-30 16:39:30 -07:00
slog.Info("closing websocket", "addr", r.RemoteAddr)
2023-12-30 14:33:52 -07:00
}
// PayloadFormatFunc returns an MQTTFormatFunc that renders a message by
// passing its payload as the single argument to the fmt-style format string
// template.
func PayloadFormatFunc(template string) MQTTFormatFunc {
	render := func(message mqtt.Message) []byte {
		formatted := fmt.Sprintf(template, message.Payload())
		return []byte(formatted)
	}
	return render
}