// Package htmxmqtt forwards messages received on MQTT subscriptions to
// WebSocket clients.
package htmxmqtt

import (
	"context"
	"fmt"
	"net/http"
	"os"
	"slices"
	"sync"
	"time"

	mqtt "github.com/eclipse/paho.mqtt.golang"
	"nhooyr.io/websocket"
)

// MQTTFormatFunc renders an MQTT message into the bytes that are sent to each
// WebSocket client as a text frame.
type MQTTFormatFunc func(mqtt.Message) []byte

// MQTTHandler fans out messages from one MQTT subscription to every WebSocket
// connection currently being served by its ServeHTTP method.
//
// The embedded Mutex guards channels.
type MQTTHandler struct {
	sync.Mutex
	format   MQTTFormatFunc
	channels []chan mqtt.Message
}

// processMessage is the paho callback registered for the handler's
// subscription. It offers message to every registered subscriber without
// blocking. A subscriber whose buffer is still full is evicted: its channel is
// removed from the list and closed, so the owning ServeHTTP call notices and
// ends the connection instead of waiting on a channel that is never fed again.
func (handler *MQTTHandler) processMessage(_ mqtt.Client, message mqtt.Message) {
	message.Ack()

	handler.Lock()
	defer handler.Unlock()

	remaining := make([]chan mqtt.Message, 0, len(handler.channels))
	for _, channel := range handler.channels {
		select {
		case channel <- message:
			remaining = append(remaining, channel)
		default:
			os.Stderr.WriteString("Channel overflow\n")
			// Closing is safe: processMessage is the only sender, and the
			// channel leaves the list under this same lock, so nothing will
			// send to it or close it again.
			close(channel)
		}
	}
	handler.channels = remaining
}

// addChannel registers channel to receive future messages and returns a
// function that unregisters it. The returned function is a no-op if the
// channel has already been evicted by processMessage, so it is safe to defer.
func (handler *MQTTHandler) addChannel(channel chan mqtt.Message) func() {
	handler.Lock()
	defer handler.Unlock()
	handler.channels = append(handler.channels, channel)

	return func() {
		handler.Lock()
		defer handler.Unlock()
		idx := slices.Index(handler.channels, channel)
		if idx < 0 {
			return
		}
		// Subscriber order does not matter, so move the last element into
		// the hole instead of shifting the tail; clear the vacated slot so
		// the backing array does not keep the channel reachable.
		last := len(handler.channels) - 1
		handler.channels[idx] = handler.channels[last]
		handler.channels[last] = nil
		handler.channels = handler.channels[:last]
	}
}

// MQTTHandlerClient is a connected MQTT client that creates an MQTTHandler per
// subscription via NewHandler.
type MQTTHandlerClient struct {
	mqtt.Client
	// subscriptions records the subscription string each handler was created
	// for. NOTE(review): NewHandler writes it without locking, so NewHandler
	// must not be called concurrently.
	subscriptions map[*MQTTHandler]string
	// subscribeTimeout bounds how long NewHandler waits for a subscribe
	// request to complete.
	subscribeTimeout time.Duration
}

// NewMQTTHandlerClient connects to broker using the given credentials and
// client id, blocking until the connection attempt finishes. It returns the
// connection error if the attempt fails.
func NewMQTTHandlerClient(broker string, username string, password string, id string) (*MQTTHandlerClient, error) {
	opts := mqtt.NewClientOptions()
	opts.AddBroker(broker)
	opts.SetClientID(id)
	opts.SetUsername(username)
	opts.SetPassword(password)
	// Short keepalive (2s) and ping timeout (1s); presumably chosen so a lost
	// broker connection is detected quickly — tune if needed.
	opts.SetKeepAlive(2 * time.Second)
	opts.SetPingTimeout(1 * time.Second)

	client := mqtt.NewClient(opts)
	if token := client.Connect(); token.Wait() && token.Error() != nil {
		return nil, token.Error()
	}
	return &MQTTHandlerClient{
		Client:           client,
		subscriptions:    map[*MQTTHandler]string{},
		subscribeTimeout: 1 * time.Second,
	}, nil
}

// NewHandler subscribes to subscription at QoS 0 and returns a handler whose
// ServeHTTP streams the received messages, rendered by format, to WebSocket
// clients. It fails if the subscription does not complete within the
// client's subscribe timeout or the subscribe request reports an error.
//
// NOTE(review): paho keys message routes by topic, so two handlers created
// for the same subscription string likely leave only the latest one
// receiving messages — confirm if that matters for callers.
func (client *MQTTHandlerClient) NewHandler(subscription string, format MQTTFormatFunc) (*MQTTHandler, error) {
	handler := &MQTTHandler{
		format: format,
	}

	token := client.Subscribe(subscription, 0x00, handler.processMessage)
	if !token.WaitTimeout(client.subscribeTimeout) {
		// The request may still complete later and route messages to a
		// handler nobody holds; withdraw it on a best-effort basis.
		client.Unsubscribe(subscription)
		return nil, fmt.Errorf("subscribing to %q: timed out after %s", subscription, client.subscribeTimeout)
	}
	if err := token.Error(); err != nil {
		return nil, fmt.Errorf("subscribing to %q: %w", subscription, err)
	}

	client.subscriptions[handler] = subscription
	return handler, nil
}

// ServeHTTP upgrades the request to a WebSocket connection and writes every
// message the handler receives, rendered with its format function, as a text
// frame. It returns when reading from the client fails (e.g. disconnect), a
// write fails, the request context ends, or the connection is evicted by
// processMessage for falling behind.
func (handler *MQTTHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	conn, err := websocket.Accept(w, r, nil)
	if err != nil {
		// conn is nil on error; there is nothing to serve or close.
		fmt.Fprintf(os.Stderr, "websocket accept error: %s\n", err)
		return
	}
	defer conn.CloseNow()

	// Register only after a successful upgrade so failed handshakes never
	// occupy a subscriber slot.
	channel := make(chan mqtt.Message, 1)
	removeChannel := handler.addChannel(channel)
	defer removeChannel()

	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	// Keep reading so the library processes incoming frames; the first read
	// error (typically the client going away) cancels ctx and stops the
	// write loop below.
	go func() {
		defer cancel()
		for {
			msgType, data, err := conn.Read(ctx)
			if err != nil {
				fmt.Fprintf(os.Stderr, "websocket error: %s\n", err)
				return
			}
			fmt.Fprintf(os.Stderr, "websocket data(%s): %s\n", msgType, data)
		}
	}()

	for {
		select {
		case <-ctx.Done():
			os.Stderr.WriteString("websocket context done\n")
			return
		case message, ok := <-channel:
			if !ok {
				// processMessage closed the channel after an overflow.
				os.Stderr.WriteString("websocket subscriber evicted after channel overflow\n")
				return
			}
			text := handler.format(message)
			fmt.Fprintf(os.Stderr, "websocket write: %s\n", text)
			if err := conn.Write(ctx, websocket.MessageText, text); err != nil {
				fmt.Fprintf(os.Stderr, "websocket write error: %s\n", err)
				return
			}
		}
	}
}

// PayloadFormatFunc returns an MQTTFormatFunc that substitutes the raw message
// payload into template, which is used as an fmt format string (for example
// "<div>%s</div>").
//
// NOTE(review): the payload is inserted verbatim with no HTML escaping. If the
// output is rendered as HTML and the topic can carry untrusted data, escape it
// (e.g. html.EscapeString) before substitution.
func PayloadFormatFunc(template string) MQTTFormatFunc {
	return func(message mqtt.Message) []byte {
		return fmt.Appendf(nil, template, message.Payload())
	}
}