import { isPatientAndGroup } from '@om1/platform-assistant-ui/src/components/query/nodes/And'
import { NodeTypes } from '@om1/platform-assistant-ui/src/components/query/QueryBlock'
import { useMutation, useQuery } from '@tanstack/react-query'
import { useEffect, useRef, useState } from 'react'
import { falconApiConfig } from '../../falcon-api-config'
import {
    AndDTO_Output,
    CohortCreateDTO,
    CohortDTO,
    DateAwareFilterDTO_Output,
    DateWindowFilterDTO_Output,
    ExceptDTO_Output,
    ExplorerCohortsService,
    ExplorerTasksService,
    OpenAPI,
    OrDTO_Output,
    RelativeDateFilterDTO_Output,
    RelativeFollowUpFilterDTO_Output
} from '../client'

/** Query-key prefixes passed to `useQuery` by the cohort hooks in this module. */
export const QUERY_KEYS = { cohort: 'cohort', cohortSize: 'cohortSize' } as const

// Type guards over the query-node union, discriminated by the node's `type` field.
const isAnd = (filter: NodeTypes): filter is AndDTO_Output => filter.type === 'AndDTO'
const isOr = (filter: NodeTypes): filter is OrDTO_Output => filter.type === 'OrDTO'
const isExcept = (filter: NodeTypes): filter is ExceptDTO_Output => filter.type === 'ExceptDTO'
const isDateAwareFilter = (filter: NodeTypes): filter is DateAwareFilterDTO_Output => filter.type === 'DateAwareFilterDTO'
const isRelativeDateFilter = (filter: NodeTypes): filter is RelativeDateFilterDTO_Output => filter.type === 'RelativeDateFilterDTO'
const isDateWindowFilter = (filter: NodeTypes): filter is DateWindowFilterDTO_Output => filter.type === 'DateWindowFilterDTO'
const isRelativeFollowUpFilter = (filter: NodeTypes): filter is RelativeFollowUpFilterDTO_Output =>
    filter.type === 'RelativeFollowUpFilterDTO'

/**
 * Depth-first search below (and possibly including) `filter` for the node whose `id` equals `id`.
 *
 * Children of And / Or / Except nodes are matched directly on their `id`, whatever their type.
 * The node passed in as `filter` matches itself only when it is one of:
 * - a patient-level And group (`isPatientAndGroup`),
 * - a date-aware filter,
 * - a date-window filter,
 * - a relative-date filter,
 * - a relative-follow-up filter.
 * Relative-date filters are searched through `subjectFilter` then `referenceFilter`.
 * Relative-follow-up filters are searched through `baseline` then `followUp`.
 *
 * @returns the first matching node, or `null` when none is found (or `filter` is null/undefined).
 */
function findSelectedBlock(filter: NodeTypes | null | undefined, id: number): NodeTypes | null {
    if (!filter) return null

    if (isAnd(filter)) {
        if (isPatientAndGroup(filter) && filter.id === id) {
            return filter
        }
        for (const andFilter of filter.and) {
            if (andFilter.id === id) return andFilter
            const node = findSelectedBlock(andFilter, id)
            if (node) return node
        }
    } else if (isOr(filter)) {
        for (const subFilter of filter.or) {
            if (subFilter.id === id) return subFilter
            const node = findSelectedBlock(subFilter, id)
            if (node) return node
        }
    } else if (isExcept(filter)) {
        for (const subFilter of filter.except) {
            if (subFilter.id === id) return subFilter
            const node = findSelectedBlock(subFilter, id)
            if (node) return node
        }
    } else if (isDateAwareFilter(filter) && filter.id === id) {
        return filter
    } else if (isRelativeDateFilter(filter)) {
        if (filter.id === id) return filter
        return findSelectedBlock(filter.subjectFilter, id) ?? findSelectedBlock(filter.referenceFilter, id)
    } else if (isDateWindowFilter(filter) && filter.id === id) {
        return filter
    } else if (isRelativeFollowUpFilter(filter)) {
        if (filter.id === id) return filter
        return findSelectedBlock(filter.baseline, id) ?? findSelectedBlock(filter.followUp, id)
    }
    return null
}

/** The fields of a size-update task message that are read here; everything else is ignored. */
type CohortSizeTaskMessage = {
    status?: string
    result?: {
        cohortSize?: CohortDTO['cohortSize']
    }
}

/**
 * Loads a cohort by id and keeps its `cohortSize` up to date.
 *
 * Each successful cohort fetch bumps a counter that is part of the size query's key. This
 * re-requests the size-update task. When that request returns a `taskId`:
 * - a WebSocket is opened to `<falconApiUrl with http→ws>/tasks/ws/<taskId>`;
 * - `{ token }` is sent on open;
 * - a `success` message carrying `result.cohortSize` is merged into the cohort state.
 * At most one such socket is open at a time, and it is closed when the hook's component unmounts.
 *
 * @param params - request parameters for the cohort endpoint; only `cohortId` is used.
 * @returns the cohort, the cohort query's `error`/`isLoading`, `refetchCohort`, and `findBlock`,
 *          which locates a query node by id within `cohort.query.filters`.
 */
export function useCohort(params: Parameters<typeof ExplorerCohortsService.getExplorerCohortsCohortIdGet>[0]) {
    const [cohort, setCohort] = useState<CohortDTO>()
    // Incremented on every successful cohort fetch; part of the size query's key so that query re-runs.
    const [cohortSizeUpdate, setCohortSizeUpdate] = useState(0)
    const { error, isLoading, refetch } = useQuery({
        queryKey: [QUERY_KEYS.cohort, params.cohortId],
        queryFn: () =>
            ExplorerCohortsService.getExplorerCohortsCohortIdGet({
                cohortId: params.cohortId
            }),
        onSuccess: (data) => {
            setCohort(data)
            setCohortSizeUpdate((prev) => prev + 1)
        }
    })

    // The currently open size-update socket, if any.
    const wsRef = useRef<WebSocket | null>(null)

    // Close the socket on unmount. A cleanup function returned from react-query's `onSuccess`
    // is ignored, so teardown has to live in an effect.
    useEffect(() => {
        return () => {
            wsRef.current?.close()
            wsRef.current = null
        }
    }, [])

    useQuery({
        queryKey: [QUERY_KEYS.cohortSize, cohort?.id, cohortSizeUpdate],
        queryFn: () => ExplorerTasksService.cohortUpdateSizeExplorerTasksCohortCohortIdUpdateSizeGet({ cohortId: cohort?.id || '' }),
        enabled: Boolean(cohort?.id),
        onSuccess: (data) => {
            if (!data?.taskId) return

            const baseUrl = falconApiConfig.falconApiUrl
            if (!baseUrl) {
                console.error('Cannot open update size websocket: falconApiUrl is not configured')
                return
            }

            // A newer size task supersedes any earlier one; don't leave the old socket open.
            wsRef.current?.close()

            const ws = new WebSocket(`${baseUrl.replace(/^http/, 'ws')}/tasks/ws/${data.taskId}`)
            wsRef.current = ws

            ws.onopen = () => {
                // NOTE(review): String() assumes OpenAPI.TOKEN is a plain string, not a resolver function — confirm.
                ws.send(JSON.stringify({ token: String(OpenAPI.TOKEN) }))
            }

            ws.onmessage = (event) => {
                // Ignore messages from a socket that has been replaced or torn down.
                if (wsRef.current !== ws) return

                let message: CohortSizeTaskMessage | null
                try {
                    message = JSON.parse(event.data)
                } catch (e: unknown) {
                    console.error('Ignoring malformed update size websocket message', e)
                    return
                }

                const cohortSize = message?.result?.cohortSize
                // `!= null` rather than truthiness, so that a legitimate size of 0 is applied.
                if (message?.status?.toLowerCase() === 'success' && cohortSize != null) {
                    // Functional update: merge into the latest cohort, not the one captured when the socket opened.
                    setCohort((prev) => (prev ? { ...prev, cohortSize } : prev))
                }
            }

            ws.onclose = () => {
                console.log('Update size websocket connection closed')
            }
        }
    })

    return {
        cohort,
        error,
        isLoading,
        refetchCohort: refetch,
        isRefreshing: false,
        /** Searches each of `cohort.query.filters` in order; returns the first node with `id`, or null (also before the cohort loads). */
        findBlock: (id: number): NodeTypes | null => {
            const query = cohort?.query
            if (!query) return null

            for (const filter of query.filters || []) {
                const found = findSelectedBlock(filter, id)
                if (found) return found
            }

            return null
        }
    }
}

/**
 * Mutation hook for creating an explorer cohort.
 *
 * The variables passed to `mutate`/`mutateAsync` are a `CohortCreateDTO`. They are forwarded
 * unchanged as the `requestBody` of `ExplorerCohortsService.createExplorerCohortsPost`.
 */
export function useCreateCohort() {
    const createCohort = (requestBody: CohortCreateDTO) => ExplorerCohortsService.createExplorerCohortsPost({ requestBody })

    return useMutation({ mutationFn: createCohort })
}
