// S3FileViewer.tsx
import React, { useEffect, useState } from "react";
import {
  fetchFiles,
  deleteFile,
  getDownloadUrl,
  S3File,
} from "../../../../../components/S3/S3Utils";
import {
  Paper,
  Table,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
  Grid,
  Tabs,
  Tab,
  Chip,
  CircularProgress,
  FormControl,
  Autocomplete,
  Checkbox,
  IconButton,
  Toolbar,
  Tooltip,
  Typography,
  Dialog,
  DialogActions,
  DialogContent,
  DialogContentText,
  DialogTitle,
  Button,
  Box,
  List,
  ListItemButton,
  ListItemIcon,
  ListItemText,
  AutocompleteChangeReason,
  AutocompleteChangeDetails,
  Alert,
  TextField,
} from "@mui/material";
import DeleteIcon from "@mui/icons-material/Delete";
import DownloadIcon from "@mui/icons-material/Download";
import { readJsonFromS3 } from "../../../../../components/S3/S3Utils";
import Plot from "react-plotly.js";
import { Datum } from "plotly.js";

/**
 * One statistic's entry in `ApiResponse.statistical_analysis` (e.g. the "mean"
 * entry). The statistics table indexes it by column name and then by variable
 * name; the leaf values are passed through `parseFloat` before display.
 * Left loosely typed because the leaf value type is not fixed by this file.
 */
type StatisticalAnalysisProps = Record<string, any>;

/**
 * One crosstab from `ApiResponse.crosstab_plots`: outer category -> inner
 * category -> count. This is the same shape the pie, sankey and sunburst
 * chart props declare for these entries, so it is spelled out here instead of
 * an `any` index signature that hid it from the type checker.
 */
interface CrossTabPlotsProps {
  [x: string]: { [y: string]: number };
}

/**
 * Contents of one inference-job response JSON file read from S3.
 * Every field is optional because the viewer starts from an empty object
 * before any response has been read.
 */
interface ApiResponse {
  // Job inputs; the viewer displays only the part after the last "/".
  input_file_name?: string | null;
  input_model_name?: string | null;

  // Job outputs. `out_location` is expected to be an "s3://<bucket>/<key>"
  // URI: the download action strips the "s3://<bucket>/" prefix to get a key.
  out_location?: string | null;
  out_log_id?: string | null;
  out_generic_model_id?: number | null;
  merged_report?: string | null;

  // Chart data: crosstabs keyed by column name, and statistics keyed by
  // statistic name (e.g. "mean").
  crosstab_plots?: Record<string, CrossTabPlotsProps>;
  statistical_analysis?: Record<string, StatisticalAnalysisProps>;

  // Job outcome; "requested" and "failed" are treated specially by the viewer.
  warnings?: string[] | null;
  error?: string | null;
  status?: string | null;
}

/** Props of `InferenceResponseViewer`. */
interface InferenceResponseViewerProps {
  /**
   * Opaque change token: any new value makes the viewer re-list the response
   * files. It is only compared, never read, so `unknown` is enough (was `any`).
   */
  refresh: unknown;
  /** S3 bucket holding the response files. */
  bucketName: string;
  /** Prefix under which `<pathPrefix>/responses/<serviceName>` is listed. */
  pathPrefix: string;
  /** Service whose responses are listed. */
  serviceName: string;
}

/** One crosstab from `ApiResponse.crosstab_plots`: outer category -> inner category -> count. */
type CrosstabCounts = { [x: string]: { [y: string]: number } };

/** Props shared by every chart that renders one entry of `crosstab_plots`. */
interface CrosstabChartProps {
  crosstab_plots?: ApiResponse["crosstab_plots"];
  key_name: string | null;
}

/**
 * Looks up the crosstab stored under `keyName`.
 *
 * Returns `undefined` when nothing is selected, there are no plots, or the key
 * is not an own property of `crosstabPlots`. The last case happens when a
 * column chosen for one job is absent from the next job's response. The
 * own-property check also stops names like "constructor" from resolving to
 * `Object.prototype` members.
 */
const pickCrosstab = (
  crosstabPlots: ApiResponse["crosstab_plots"],
  keyName: string | null
): CrosstabCounts | undefined =>
  crosstabPlots &&
  keyName &&
  Object.prototype.hasOwnProperty.call(crosstabPlots, keyName)
    ? crosstabPlots[keyName]
    : undefined;

/** Distinct inner categories across all outer rows, in first-seen order. */
const innerCategories = (data: CrosstabCounts): string[] =>
  Array.from(new Set(Object.values(data).flatMap((row) => Object.keys(row))));

/** Part of `path` after its last "/" (the whole string when there is none). */
const baseName = (path?: string | null): string | undefined =>
  path?.substring(path.lastIndexOf("/") + 1);

/** Chip shown for a job status; anything but "requested"/"failed" shows as "done". */
const statusChipProps = (status?: string | null) => {
  if (status === "requested") {
    return { label: "requested", color: "warning" as const };
  }
  if (status === "failed") {
    return { label: "failed", color: "error" as const };
  }
  return { label: "done", color: "success" as const };
};

/** Stacked bar chart: one trace per inner category, stacked over the outer categories. */
const StackedHistogram: React.FC<CrosstabChartProps> = ({
  crosstab_plots,
  key_name,
}) => {
  const data = pickCrosstab(crosstab_plots, key_name);
  if (!data || !key_name) {
    return <Typography variant="h6">Stacked histogram not found</Typography>;
  }

  const xValues = Object.keys(data);
  const traces = innerCategories(data).map((y) => ({
    x: xValues,
    y: xValues.map((x) => data[x][y] || 0),
    name: y,
    type: "bar" as const,
  }));

  return (
    <div>
      <Plot
        data={traces}
        layout={{
          barmode: "stack",
          title: "Stacked Histogram",
          xaxis: { title: key_name },
          yaxis: { title: "Count" },
        }}
      />
    </div>
  );
};

/** Heatmap of counts: columns follow the outer categories, rows the inner ones. */
const CrosstabHeatmap: React.FC<CrosstabChartProps> = ({
  crosstab_plots,
  key_name,
}) => {
  const data = pickCrosstab(crosstab_plots, key_name);
  if (!data || !key_name) {
    return <Typography variant="h6">Heatmap not found</Typography>;
  }

  const xValues = Object.keys(data);
  const yValues = innerCategories(data);
  // z[row][col]: one row per inner category, missing cells count as 0.
  const zValues = yValues.map((y) => xValues.map((x) => data[x][y] || 0));

  return (
    <div>
      <Plot
        data={[
          {
            x: xValues,
            y: yValues,
            z: zValues,
            type: "heatmap",
            colorscale: "Viridis",
          },
        ]}
        layout={{
          title: `Crosstab Heatmap for ${key_name}`,
          xaxis: { title: key_name, tickvals: xValues, ticktext: xValues },
          yaxis: { title: "Target Label" },
        }}
      />
    </div>
  );
};

/** Pie chart of each inner category's total count, summed over all outer rows. */
const PieChart: React.FC<CrosstabChartProps> = ({ crosstab_plots, key_name }) => {
  const data = pickCrosstab(crosstab_plots, key_name);
  if (!data || !key_name) {
    return <Typography variant="h6">Pie chart not found</Typography>;
  }

  // Map (not a plain object) so category names cannot collide with prototype keys.
  const totals = new Map<string, number>();
  for (const row of Object.values(data)) {
    for (const [category, count] of Object.entries(row)) {
      totals.set(category, (totals.get(category) || 0) + count);
    }
  }

  return (
    <div>
      <Plot
        data={[
          {
            labels: Array.from(totals.keys()),
            values: Array.from(totals.values()),
            type: "pie",
            textinfo: "label+percent",
            textposition: "outside",
            automargin: true,
          },
        ]}
        layout={{
          title: `Pie Chart for ${key_name}`,
          showlegend: false,
        }}
        config={{
          responsive: true,
        }}
      />
    </div>
  );
};

/** Statistics shown as table rows, in display order. */
const STATISTIC_ROWS = [
  "mean",
  "median",
  "minimum",
  "maximum",
  "std",
  "percentile_25",
  "percentile_50",
  "percentile_75",
  "count",
  // "missing",
  // "missing_percent",
] as const;

/** Formats one statistic to 2 decimals; missing (undefined or JSON null) values show as "N/A". */
const formatStatistic = (value: unknown): string =>
  value === undefined || value === null
    ? "N/A"
    : parseFloat(String(value)).toFixed(2);

interface StatisticsTableProps {
  statistical_analysis?: ApiResponse["statistical_analysis"];
  key_name: string | null;
}

/**
 * Table of `statistical_analysis[statistic][key_name][variable]`.
 * The variable columns are taken from the first statistic ("mean"), and every
 * row is rendered against that same column list so that cells stay aligned.
 */
const StatisticsTable: React.FC<StatisticsTableProps> = ({
  statistical_analysis,
  key_name,
}) => {
  const notFound = (
    <Typography variant="h6">Statistical Analysis Not Found!</Typography>
  );
  if (!statistical_analysis || !key_name) {
    return notFound;
  }

  const firstStatistic: unknown =
    statistical_analysis[STATISTIC_ROWS[0]]?.[key_name];
  if (typeof firstStatistic !== "object" || firstStatistic === null) {
    return notFound;
  }
  const variables = Object.keys(firstStatistic);

  return (
    <TableContainer
      component={Paper}
      sx={{ maxWidth: 1000, margin: "auto", mt: 4 }}
    >
      <Table aria-label="statistics table">
        <TableHead>
          <TableRow>
            <TableCell>Variable</TableCell>
            {variables.map((variable) => (
              <TableCell key={variable} align="right">
                {variable}
              </TableCell>
            ))}
          </TableRow>
        </TableHead>
        <TableBody>
          {STATISTIC_ROWS.map((statistic) => {
            // May be missing for a given statistic; cells then show "N/A".
            const values = statistical_analysis[statistic]?.[key_name];
            return (
              <TableRow key={statistic}>
                <TableCell component="th" scope="row">
                  {statistic}
                </TableCell>
                {variables.map((variable) => (
                  <TableCell key={variable} align="right">
                    {formatStatistic(values?.[variable])}
                  </TableCell>
                ))}
              </TableRow>
            );
          })}
        </TableBody>
      </Table>
    </TableContainer>
  );
};

/**
 * Sankey diagram from outer categories (sources) to inner categories (targets).
 * Sources and targets are kept as separate nodes even when their names match.
 * Otherwise a category and a label with the same name would merge into one
 * node and form a loop.
 */
const SankeyDiagram: React.FC<CrosstabChartProps> = ({
  crosstab_plots,
  key_name,
}) => {
  const data = pickCrosstab(crosstab_plots, key_name);
  if (!data || !key_name) {
    return <Typography variant="h6">Sankey diagram not found</Typography>;
  }

  const labels: string[] = [];
  const sourceIds = new Map<string, number>();
  const targetIds = new Map<string, number>();
  // Returns the node index for `name`, registering a new node on first use.
  // Uses an explicit `undefined` check: a truthiness test would treat index 0
  // as missing and register the first node twice.
  const nodeId = (ids: Map<string, number>, name: string): number => {
    let id = ids.get(name);
    if (id === undefined) {
      id = labels.length;
      ids.set(name, id);
      labels.push(name);
    }
    return id;
  };

  const sources: number[] = [];
  const targets: number[] = [];
  const values: number[] = [];
  for (const [outer, row] of Object.entries(data)) {
    const sourceId = nodeId(sourceIds, outer);
    for (const [inner, count] of Object.entries(row)) {
      // Nodes are registered even for zero counts; only positive counts become links.
      const targetId = nodeId(targetIds, inner);
      if (count > 0) {
        sources.push(sourceId);
        targets.push(targetId);
        values.push(count);
      }
    }
  }

  return (
    <div>
      <Plot
        data={[
          {
            type: "sankey",
            orientation: "h",
            node: {
              pad: 15,
              thickness: 20,
              line: {
                color: "black",
                width: 0.5,
              },
              label: labels,
            },
            link: {
              source: sources,
              target: targets,
              value: values,
            },
          },
        ]}
        layout={{
          title: `Sankey Diagram for ${key_name}`,
          font: {
            size: 10,
          },
        }}
      />
    </div>
  );
};

/**
 * Sunburst plot: the inner ring holds the outer categories (sized by their
 * total count), and the outer ring holds each one's inner categories.
 * Plotly `ids` keep sectors unique, so labels can repeat across parents
 * without being mangled with an index prefix.
 */
const SunburstPlot: React.FC<CrosstabChartProps> = ({
  crosstab_plots,
  key_name,
}) => {
  const data = pickCrosstab(crosstab_plots, key_name);
  if (!data || !key_name) {
    return <Typography variant="h6">Sunburst plot not found</Typography>;
  }

  const ids: string[] = [];
  const labels: string[] = [];
  const parents: string[] = [];
  const values: number[] = [];

  Object.entries(data).forEach(([point, prediction], i) => {
    const rootId = `${i}`;
    ids.push(rootId);
    labels.push(point);
    parents.push("");
    // branchvalues "total" requires a parent to be at least the sum of its children.
    values.push(Object.values(prediction).reduce((a, b) => a + b, 0));

    Object.entries(prediction).forEach(([label, count], j) => {
      ids.push(`${i}/${j}`);
      labels.push(label);
      parents.push(rootId);
      values.push(count);
    });
  });

  return (
    <div>
      <Plot
        data={[
          {
            type: "sunburst",
            ids: ids,
            labels: labels,
            parents: parents,
            values: values,
            branchvalues: "total",
          },
        ]}
        layout={{
          title: `Sunburst Plot for ${key_name}`,
          font: {
            size: 10,
          },
        }}
      />
    </div>
  );
};

/** Chart tabs in display order; a tab's index is the `activeTab` value that shows it. */
const CHART_TABS = [
  { name: "histogram", label: "Stacked Histogram" },
  { name: "crosstab", label: "Crosstab Heatmap" },
  { name: "pie", label: "Pie Chart" },
  { name: "stats", label: "Statistical Analysis Table" },
  { name: "sankey", label: "Sankey Diagrams" },
  { name: "sunburst", label: "Sunburst Plots" },
] as const;

interface ChartPanelProps {
  /** Tab name; the panel gets id `tabpanel-<name>` and is labelled by `tab-<name>`. */
  name: string;
  hidden: boolean;
  /** Label of the column picker. */
  label: string;
  options: string[];
  value: string | null;
  onChange: (value: string | null) => void;
  children?: React.ReactNode;
}

/** Tab panel with a column picker above the chart passed as `children`. */
const ChartPanel: React.FC<ChartPanelProps> = ({
  name,
  hidden,
  label,
  options,
  value,
  onChange,
  children,
}) => {
  // Show no selection rather than a value the current response does not offer.
  const selected = value !== null && options.includes(value) ? value : null;

  return (
    <div
      role="tabpanel"
      hidden={hidden}
      id={`tabpanel-${name}`}
      aria-labelledby={`tab-${name}`}
      style={{ flex: "auto", textAlign: "center", margin: "20px" }}
    >
      <FormControl fullWidth margin="normal">
        <Autocomplete
          options={options}
          getOptionLabel={(option) => option}
          value={selected}
          onChange={(_event, newValue) => onChange(newValue ? newValue : null)}
          renderInput={(params) => (
            <TextField {...params} label={label} variant="outlined" />
          )}
          fullWidth
        />
      </FormControl>
      {children}
    </div>
  );
};

/**
 * Lists the response JSON files under `<pathPrefix>/responses/<serviceName>`
 * in `bucketName`. Shows a summary of the selected job and renders its
 * crosstab and statistics charts in tabs.
 *
 * A new `refresh` value re-lists the files and re-reads the selected response.
 */
const InferenceResponseViewer: React.FC<InferenceResponseViewerProps> = ({
  refresh,
  bucketName,
  pathPrefix,
  serviceName,
}) => {
  // Prefix holding this service's responses; list entries are shown relative to it.
  const responsesPrefix = pathPrefix + "/responses/" + serviceName;

  const [files, setFiles] = useState<S3File[]>([]);
  const [selectedIndex, setSelectedIndex] = useState(0);
  const [apiResponse, setApiResponse] = useState<ApiResponse>({});
  const [isLoading, setIsLoading] = useState(false);
  const [activeTab, setActiveTab] = useState(0);

  // Column chosen in each chart tab; each tab keeps its own selection.
  const [stackedHistogramColumn, setStackedHistogramColumn] = useState<
    string | null
  >(null);
  const [crosstabHeatmapColumn, setCrosstabHeatmapColumn] = useState<
    string | null
  >(null);
  const [pieChartColumn, setPieChartColumn] = useState<string | null>(null);
  const [statsTableColumn, setStatsTableColumn] = useState<string | null>(null);
  const [sankeyDiagramColumn, setSankeyDiagramColumn] = useState<string | null>(
    null
  );
  const [sunburstPlotColumn, setSunburstPlotColumn] = useState<string | null>(
    null
  );

  // (Re)list the response files. `cancelled` drops results from a superseded run.
  useEffect(() => {
    let cancelled = false;
    const loadFiles = async () => {
      try {
        const fetchedFiles = await fetchFiles(bucketName, responsesPrefix, [
          "json",
        ]);
        if (cancelled) {
          return;
        }
        setFiles(fetchedFiles);
        // Keep the selection in range if the list shrank.
        setSelectedIndex((index) => (index < fetchedFiles.length ? index : 0));
      } catch (error) {
        console.error("Failed to fetch files:", error);
      }
    };
    loadFiles();
    return () => {
      cancelled = true;
    };
  }, [refresh, bucketName, responsesPrefix]);

  const selectedKey = files[selectedIndex]?.Key;

  // Read the selected response. `files` is a dependency on purpose: a refresh
  // re-reads the same key, so a job's status can update.
  useEffect(() => {
    if (selectedKey === undefined) {
      setApiResponse({});
      setIsLoading(false);
      return;
    }
    let cancelled = false;
    const fetchData = async () => {
      setIsLoading(true);
      try {
        const response = await readJsonFromS3(bucketName, selectedKey);
        if (!cancelled) {
          setApiResponse((response as ApiResponse | null) ?? {});
        }
      } catch (error) {
        console.error("Got an error: ", error);
      } finally {
        if (!cancelled) {
          setIsLoading(false);
        }
      }
    };
    fetchData();
    return () => {
      cancelled = true;
    };
  }, [bucketName, selectedKey, files]);

  const handleDownload = async (key: string) => {
    try {
      const url = await getDownloadUrl(bucketName, key);
      window.open(url, "_blank");
    } catch (error) {
      console.error("Failed to download file:", error);
    }
  };

  // NOTE(review): the statistics tab also offers the crosstab columns, as
  // before; confirm they match the columns in `statistical_analysis`.
  const columnOptions = apiResponse.crosstab_plots
    ? Object.keys(apiResponse.crosstab_plots)
    : [];
  const statusChip = statusChipProps(apiResponse.status);
  const warnings = apiResponse.warnings;

  return (
    <div>
      <Grid container spacing={3}>
        <Grid item xs={12} sm={6}>
          <Paper style={{ maxHeight: 400, overflow: "auto" }}>
            <List
              component="nav"
              aria-label="inference jobs"
              title="List of Inference Jobs"
            >
              {files.map((file, index) => (
                <ListItemButton
                  key={file.Key}
                  selected={selectedIndex === index}
                  onClick={() => setSelectedIndex(index)}
                >
                  <ListItemText
                    primary={file.Key.replace(responsesPrefix + "/", "")}
                  />
                </ListItemButton>
              ))}
            </List>
          </Paper>
        </Grid>
        <Grid item xs={12} sm={6}>
          <TableContainer component={Paper}>
            <Table aria-label="inference job summary">
              <TableBody>
                <TableRow>
                  <TableCell>Test Dataset</TableCell>
                  <TableCell>{baseName(apiResponse.input_file_name)}</TableCell>
                </TableRow>

                <TableRow>
                  <TableCell>Model Name</TableCell>
                  <TableCell>{baseName(apiResponse.input_model_name)}</TableCell>
                </TableRow>

                <TableRow>
                  <TableCell>Out S3 File Name</TableCell>
                  <TableCell>{baseName(apiResponse.out_location)}</TableCell>
                </TableRow>

                <TableRow>
                  <TableCell>Out Log ID</TableCell>
                  <TableCell>
                    {apiResponse.out_log_id ? apiResponse.out_log_id : "N/A"}
                  </TableCell>
                </TableRow>

                <TableRow>
                  <TableCell>Out LithoLens Generic Model ID</TableCell>
                  <TableCell>
                    {/* `??` so that an id of 0 is still displayed */}
                    {apiResponse.out_generic_model_id ?? "N/A"}
                  </TableCell>
                </TableRow>

                <TableRow>
                  <TableCell>Warnings</TableCell>
                  <TableCell>
                    {warnings && warnings.length > 0
                      ? warnings.map((warning, index) => (
                          <div key={index}>{warning}</div>
                        ))
                      : "N/A"}
                  </TableCell>
                </TableRow>

                <TableRow>
                  <TableCell>Error</TableCell>
                  <TableCell>
                    {apiResponse.error ? apiResponse.error : "N/A"}
                  </TableCell>
                </TableRow>

                <TableRow>
                  <TableCell>Download Predictions</TableCell>
                  <TableCell>
                    <Tooltip title="Download">
                      {/* span: a Tooltip cannot attach to a disabled button directly */}
                      <span>
                        <IconButton
                          onClick={() =>
                            handleDownload(
                              apiResponse.out_location
                                ? apiResponse.out_location.replace(
                                    "s3://" + bucketName + "/",
                                    ""
                                  )
                                : ""
                            )
                          }
                          disabled={
                            apiResponse.out_location === null ||
                            apiResponse.out_location === undefined
                          }
                          aria-label="download"
                        >
                          <DownloadIcon />
                        </IconButton>
                      </span>
                    </Tooltip>
                  </TableCell>
                </TableRow>

                <TableRow>
                  <TableCell>Status</TableCell>
                  <TableCell>
                    <Chip label={statusChip.label} color={statusChip.color} />
                  </TableCell>
                </TableRow>
              </TableBody>
            </Table>
          </TableContainer>
        </Grid>
      </Grid>
      <Tabs
        variant="scrollable"
        scrollButtons="auto"
        value={activeTab}
        onChange={(_event: React.SyntheticEvent, newValue: number) =>
          setActiveTab(newValue)
        }
        aria-label="file tabs"
        sx={{
          backgroundColor: "white", // White background for the whole tabs bar
          boxShadow: "0 2px 4px rgba(0,0,0,0.1)", // Subtle shadow under the tabs bar
          "& .MuiTabs-flexContainer": {
            gap: "10px", // Space between tabs
          },
        }}
      >
        {CHART_TABS.map(({ name, label }) => (
          <Tab
            key={name}
            label={label}
            id={`tab-${name}`}
            aria-controls={`tabpanel-${name}`}
          />
        ))}
      </Tabs>
      {isLoading ? (
        <div style={{ textAlign: "center", margin: "15px" }}>
          <CircularProgress />
        </div>
      ) : apiResponse.status === "requested" ? (
        <div>
          <Alert severity="warning">Job in progress</Alert>
        </div>
      ) : apiResponse.status === "failed" ? (
        <div>
          <Alert severity="error">Job failed!</Alert>
        </div>
      ) : (
        <div>
          <ChartPanel
            name="histogram"
            hidden={activeTab !== 0}
            label="Stacked Histogram Column"
            options={columnOptions}
            value={stackedHistogramColumn}
            onChange={setStackedHistogramColumn}
          >
            <StackedHistogram
              crosstab_plots={apiResponse.crosstab_plots}
              key_name={stackedHistogramColumn}
            />
          </ChartPanel>

          <ChartPanel
            name="crosstab"
            hidden={activeTab !== 1}
            label="Crosstab Heatmap Column"
            options={columnOptions}
            value={crosstabHeatmapColumn}
            onChange={setCrosstabHeatmapColumn}
          >
            <CrosstabHeatmap
              crosstab_plots={apiResponse.crosstab_plots}
              key_name={crosstabHeatmapColumn}
            />
          </ChartPanel>

          <ChartPanel
            name="pie"
            hidden={activeTab !== 2}
            label="Pie Chart Column"
            options={columnOptions}
            value={pieChartColumn}
            onChange={setPieChartColumn}
          >
            <PieChart
              crosstab_plots={apiResponse.crosstab_plots}
              key_name={pieChartColumn}
            />
          </ChartPanel>

          <ChartPanel
            name="stats"
            hidden={activeTab !== 3}
            label="Statistical Analysis Table"
            options={columnOptions}
            value={statsTableColumn}
            onChange={setStatsTableColumn}
          >
            <StatisticsTable
              statistical_analysis={apiResponse.statistical_analysis}
              key_name={statsTableColumn}
            />
          </ChartPanel>

          <ChartPanel
            name="sankey"
            hidden={activeTab !== 4}
            label="Sankey Diagram Selection"
            options={columnOptions}
            value={sankeyDiagramColumn}
            onChange={setSankeyDiagramColumn}
          >
            <SankeyDiagram
              crosstab_plots={apiResponse.crosstab_plots}
              key_name={sankeyDiagramColumn}
            />
          </ChartPanel>

          <ChartPanel
            name="sunburst"
            hidden={activeTab !== 5}
            label="Sunburst Plot Selection"
            options={columnOptions}
            value={sunburstPlotColumn}
            onChange={setSunburstPlotColumn}
          >
            <SunburstPlot
              crosstab_plots={apiResponse.crosstab_plots}
              key_name={sunburstPlotColumn}
            />
          </ChartPanel>
        </div>
      )}
    </div>
  );
};

// The viewer is exposed as this module's default export.
export default InferenceResponseViewer;
