package main import ( "errors" "time" "net/http" "os" "os/signal" "syscall" "fmt" "context" "sync" "slices" "nhooyr.io/websocket" mqtt "github.com/eclipse/paho.mqtt.golang" ) type MQTTFormatFunc func(mqtt.Message) []byte type MQTTHandler struct { sync.Mutex Format MQTTFormatFunc Channels []chan mqtt.Message } func (handler *MQTTHandler) ProcessMessage(client mqtt.Client, message mqtt.Message) { message.Ack() handler.Lock() defer handler.Unlock() remaining := make([]chan mqtt.Message, 0, len(handler.Channels)) for _, channel := range(handler.Channels) { select { case channel <- message: remaining = append(remaining, channel) default: os.Stderr.WriteString("Channel overflow\n") } } handler.Channels = remaining } func (handler *MQTTHandler) AddChannel(channel chan mqtt.Message) func() { handler.Lock() defer handler.Unlock() handler.Channels = append(handler.Channels, channel) return func() { handler.Lock() defer handler.Unlock() idx := slices.Index(handler.Channels, channel) if idx < 0 { return } handler.Channels[idx] = handler.Channels[len(handler.Channels)-1] handler.Channels = handler.Channels[:len(handler.Channels)-1] } } type MQTTHandlerClient struct { mqtt.Client Subscriptions map[*MQTTHandler]string SubscribeTimeout time.Duration } func NewMQTTHandlerClient(broker string, username string, password string, id string) (*MQTTHandlerClient, error) { opts := mqtt.NewClientOptions() opts.AddBroker(broker) opts.SetClientID(id) opts.SetUsername(username) opts.SetPassword(password) opts.SetKeepAlive(2 * time.Second) opts.SetPingTimeout(1 * time.Second) client := mqtt.NewClient(opts) if token := client.Connect(); token.Wait() && token.Error() != nil { return nil, token.Error() } return &MQTTHandlerClient{ Client: client, Subscriptions: map[*MQTTHandler]string{}, SubscribeTimeout: 1*time.Second, }, nil } func (client *MQTTHandlerClient) NewHandler(subscription string, format MQTTFormatFunc) (*MQTTHandler, error) { handler := &MQTTHandler{ Format: format, } sub_token := client.Subscribe(subscription, 0x00, handler.ProcessMessage) timeout := sub_token.WaitTimeout(client.SubscribeTimeout) if timeout == false || sub_token.Error() != nil { return nil, fmt.Errorf("Failed to subscribe to %s - %e", subscription, sub_token.Error()) } client.Subscriptions[handler] = subscription return handler, nil } func (handler *MQTTHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { channel := make(chan mqtt.Message, 1) remove_channel := handler.AddChannel(channel) defer remove_channel() conn, err := websocket.Accept(w, r, nil) if err != nil { os.Stderr.WriteString(fmt.Sprintf("websocket accept error: %s\n", err)) } defer conn.CloseNow() ctx, cancel_func := context.WithCancel(context.Background()) go func(conn *websocket.Conn, cancel_func context.CancelFunc) { for true { msg_type, data, err := conn.Read(ctx) if err != nil { os.Stderr.WriteString(fmt.Sprintf("websocket error: %s\n", err)) cancel_func() break } else { os.Stderr.WriteString(fmt.Sprintf("websocket data(%s): %s\n", msg_type, string(data))) } } }(conn, cancel_func) running := true done := ctx.Done() for running == true { select { case <- done: os.Stderr.WriteString("websocket context done") running = false case message := <- channel: text := handler.Format(message) os.Stderr.WriteString(fmt.Sprintf("websocket write: %s\n", text)) err := conn.Write(ctx, websocket.MessageText, text) if err != nil { os.Stderr.WriteString(fmt.Sprintf("websocket write error: %s\n", err)) running = false } } } } func PayloadFormatFunc(template string) MQTTFormatFunc { return func(message mqtt.Message) []byte { return []byte(fmt.Sprintf(template, message.Payload())) } } func main() { handler_client, err := NewMQTTHandlerClient("tcp://localhost:1883", "", "", "htmx") if err != nil { panic(err) } handler_1, err := handler_client.NewHandler("test", PayloadFormatFunc(`

%s

`)) if err != nil { panic(err) } mux := http.NewServeMux() mux.Handle("/", http.FileServer(http.Dir("./site"))) mux.Handle("/ws", handler_1) server := &http.Server{ Handler: mux, Addr: ":8080", } sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGTERM, syscall.SIGKILL) go func(sigs chan os.Signal, server *http.Server) { <- sigs server.Close() }(sigs, server) err = server.ListenAndServe() if errors.Is(err, http.ErrServerClosed) == true { os.Stderr.WriteString("Server closed on signal\n") } else if err != nil { os.Stderr.WriteString(fmt.Sprintf("Server error: %s\n", err)) } else { os.Stderr.WriteString("Server closed\n") } }