import PartySocket from "partysocket";
import { usePartySocket } from "partysocket/react";
import { useEffect } from "react";
import { useDispatch, useStore } from "react-redux";
import { UnknownAction } from "redux";
import { ensureExhaustive } from "shared_frontend/src/utils";
import TokenHandler from "shared_utils/tokenHandler";
import privateApi from "../../features/API";
import { dashboardUtil } from "../../features/API/Dashboard/dashboard";
import { bulkAnswersUtil } from "../../features/API/bulkAnswer";
import { CompanyId, UserId } from "../../features/API/types";
import { NonStaffRequestThreadRT } from "../../features/MissionControl/types/unions";
import { RootState } from "../../store";
import {
  GeneralSocketIncomingMessage,
  SocketSubscriptionTopic,
} from "./generatedTypes";
import { SocketMessage } from "./types";

/**
 * Holder for the shared "general" PartySocket connection.
 *
 * It also records the topics subscribed to through `socketSend`, so that
 * `resubscribeToCurrentSubscriptions` can replay them on a fresh connection.
 */
export class GeneralSocket {
  private static socket?: PartySocket;
  // Topics subscribed to via `socketSend`; replayed by `resubscribeToCurrentSubscriptions`.
  private static currentSubscriptions: SocketSubscriptionTopic[] = [];

  /** Registers the socket that all subsequent sends go through (overwrites any previous one). */
  static setSocket = (socket: PartySocket) => {
    this.socket = socket;
  };

  /**
   * Sends a `subscribe` message for every tracked topic, in the order they were recorded.
   *
   * @throws Error if no socket has been registered via `setSocket`.
   */
  static resubscribeToCurrentSubscriptions = () => {
    if (!this.socket) {
      throw new Error("socket is undefined");
    }
    for (const topic of this.currentSubscriptions) {
      const message: GeneralSocketIncomingMessage = {
        type: "subscribe",
        topic,
      };
      this.socket.send(JSON.stringify(message));
    }
  };

  /**
   * Sends `message` over the socket as JSON.
   *
   * For `subscribe` messages the topic is also recorded so it can be replayed
   * on reconnect. When `replaceExisting` is set, previously recorded topics of
   * the same topic *type* are dropped first. As a result, only the newest
   * subscription of that type (e.g. the current company) is replayed.
   *
   * @throws Error if no socket has been registered via `setSocket`.
   */
  static socketSend = (message: GeneralSocketIncomingMessage) => {
    if (!this.socket) {
      throw new Error("socket is undefined");
    }

    if (message.type === "subscribe") {
      const { topic } = message;
      if (message.replaceExisting) {
        // Match on topic type, not key: a new company/user id must replace the
        // previously recorded topic of the same type.
        this.currentSubscriptions = this.currentSubscriptions.filter(
          (s) => s.type !== topic.type,
        );
      }
      this.currentSubscriptions.push(topic);
    }
    this.socket.send(JSON.stringify(message));
  };
}

/**
 * Opens the "general" PartySocket connection and applies incoming server
 * messages to the RTK Query cache.
 *
 * On every render the socket returned by `usePartySocket` is registered with
 * `GeneralSocket.setSocket`, so other code can send through it via
 * `GeneralSocket.socketSend`.
 *
 * NOTE(review): `setSocket` overwrites the previous socket, so mounting this
 * hook in more than one component would leave only the last socket
 * registered. It is presumably meant to be mounted once; confirm.
 */
export const useInitializeGeneralSocket = () => {
  const dispatch = useDispatch();
  const store = useStore();
  const socket = usePartySocket({
    host: window.ENV_VARIABLE_BACKEND_SOCKET_BASE_URL,
    room: "general",
    // Builds the connection query params: `{ token }`. Polls
    // TokenHandler.getAccessToken every 5 seconds until it returns a truthy
    // token, presumably delaying the connection until credentials exist.
    // NOTE(review): this loop never gives up while no token is available.
    query: async () => {
      let token: string | null = null;
      while (!token) {
        token = await TokenHandler.getAccessToken();
        if (!token) {
          await new Promise((r) => setTimeout(r, 5000));
        }
      }
      return {
        token: token,
      };
    },
    onOpen: () => {
      // when cloud run disconnects after 5 minute timeout, we want new connection to resubscribe to existing subscriptions
      GeneralSocket.resubscribeToCurrentSubscriptions();
    },
    onMessage: (event) => {
      // NOTE(review): the parsed payload is only cast to SocketMessage, not
      // validated (apart from "questionnaire-task" below). Malformed JSON
      // throws here.
      const socketMessage = JSON.parse(String(event.data)) as SocketMessage;
      // Each case reads case-specific `payload` fields. This relies on
      // TypeScript narrowing `socketMessage` through this aliased discriminant.
      const type = socketMessage.topicType;
      switch (type) {
        case "company-tag-invalidation":
        case "user-tag-invalidation":
          // Server-driven invalidation of the single tag carried in the payload.
          dispatch(privateApi.util.invalidateTags([socketMessage.payload.tag]));
          break;
        case "questionnaire-task": {
          // Runtime-validate the pushed thread (presumably `check` throws on a
          // shape mismatch), then replace the cached getQuestionnaireThread
          // entry for its task id with it.
          const thread = NonStaffRequestThreadRT.check(socketMessage.payload);
          dispatch(
            dashboardUtil.updateQueryData(
              "getQuestionnaireThread",
              thread.relationships.task.id,
              (prev) => {
                prev.data = thread;
              },
              // updateQueryData returns a thunk; the cast is needed because the
              // untyped `dispatch` from useDispatch() is typed for plain actions.
            ) as unknown as UnknownAction,
          );
          break;
        }
        case "bulk-answer-request": {
          // we include pagination params in the query so we need to check for updates on all pages
          const invalidatedQueries = bulkAnswersUtil.selectInvalidatedBy(
            store.getState() as RootState,
            [{ type: "BulkAnswer", id: socketMessage.payload.requestId }],
          );
          // On each cached getRequest page, merge the pushed answer into the
          // cached answer with the same id. Answers not already on a page are
          // ignored (nothing is inserted).
          invalidatedQueries.forEach((q) =>
            dispatch(
              bulkAnswersUtil.updateQueryData(
                "getRequest",
                q.originalArgs,
                (prev) => {
                  const existingQuestionIndex = prev.data.answers.findIndex(
                    (a) => a.id === socketMessage.payload.answer.id,
                  );
                  if (existingQuestionIndex !== -1) {
                    Object.assign(
                      prev.data.answers[existingQuestionIndex],
                      socketMessage.payload.answer,
                    );
                  }
                },
              ) as unknown as UnknownAction,
            ),
          );
          break;
        }
        case "bulk-answer-status": {
          // we include pagination params in the query so we need to check for updates on all pages
          const invalidatedQueries = bulkAnswersUtil.selectInvalidatedBy(
            store.getState() as RootState,
            [{ type: "BulkAnswer", id: socketMessage.payload.requestId }],
          );
          // Overwrite the request status on every cached getRequest page.
          invalidatedQueries.forEach((q) =>
            dispatch(
              bulkAnswersUtil.updateQueryData(
                "getRequest",
                q.originalArgs,
                (prev) => {
                  prev.data.status = socketMessage.payload.status;
                },
              ) as unknown as UnknownAction,
            ),
          );
          // A status change also affects dashboard data tagged "DashboardQuestionnaire".
          dispatch(dashboardUtil.invalidateTags(["DashboardQuestionnaire"]));
          break;
        }
        default:
          // Compile-time exhaustiveness check over topicType.
          // NOTE(review): runtime behaviour for an unknown topic from the
          // server depends on ensureExhaustive's implementation; confirm.
          ensureExhaustive(type);
          break;
      }
    },
  });
  // Registered during render rather than in an effect.
  // NOTE(review): presumably so the socket is set before descendant components'
  // effects (which run before this component's effects) call
  // GeneralSocket.socketSend.
  GeneralSocket.setSocket(socket);
};

/**
 * Subscribes the shared general socket to tag-invalidation topics.
 *
 * It subscribes to the company topic when `companyId` is provided and to the
 * user topic when `userId` is provided. Each message is sent with
 * `replaceExisting: true`, and the company topic is always sent before the
 * user topic. The effect re-runs whenever either id changes. Falsy ids are
 * skipped, and nothing is unsubscribed on unmount.
 *
 * Sending goes through `GeneralSocket.socketSend`, which throws if no socket
 * has been registered yet.
 *
 * @param companyId - company whose tag invalidations to receive, if any.
 * @param userId - user whose tag invalidations to receive, if any.
 */
export const useTagInvalidationSocketSubscribe = (
  companyId?: CompanyId,
  userId?: UserId,
) => {
  useEffect(() => {
    const pending: GeneralSocketIncomingMessage[] = [];

    if (companyId) {
      pending.push({
        type: "subscribe",
        replaceExisting: true,
        topic: { type: "company-tag-invalidation", key: companyId },
      });
    }

    if (userId) {
      pending.push({
        type: "subscribe",
        replaceExisting: true,
        topic: { type: "user-tag-invalidation", key: userId.toString() },
      });
    }

    for (const subscription of pending) {
      GeneralSocket.socketSend(subscription);
    }
  }, [companyId, userId]);
};
