0

I am working on creating a component that includes scatterplots for multiple datasets and KDE (Kernel Density Estimation) plots for the x and y-coordinates of those datasets, in React with TypeScript using D3.js. So far, I have successfully implemented the KDE plot and the scatter plot. However, I am having trouble aligning the KDE plots to the right and top edges of my scatterplot, similar to how scatterhist produces output in MATLAB.

I tried using the .attr("transform", translate(${width},${0})) and other such combinations but I have not been able to place it within the svg canvas.

const Scatterplot: React.FC<{
  width: number;
  height: number;
  datasets: { x: number; y: number }[][];
  labels: string[];
  plotType: "ellipse" | "rectangle";
  p?: number;
  bandwidth?: number;
}> = ({
  width,
  height,
  datasets,
  labels,
  plotType,
  p = 0.95,
  bandwidth = 4,
}) => {
  const svgRef = useRef<SVGSVGElement | null>(null);

  useEffect(() => {
    const svgElement = svgRef.current;
    if (!svgElement) return;

    const svg = d3.select(svgElement);
    svg.selectAll("*").remove();
    svg.style("background-color", "white");
    const boundsWidth = width - MARGIN.right - MARGIN.left;
    const boundsHeight = height - MARGIN.top - MARGIN.bottom;
    const g = svg
      .attr("width", width + MARGIN.left + MARGIN.right)
      .attr("height", height + MARGIN.top + MARGIN.bottom)
      .append("g")
      .attr("transform", `translate(${MARGIN.left},${MARGIN.top})`);

    const colors = d3.schemeCategory10;

    const mu: [number, number][] = datasets.map((dataset) =>
      calculateMean(dataset)
    );

    const Sigma = datasets.map((dataset, index) =>
      calculateCovarianceMatrix(dataset, mu[index])
    );

    const maxX = d3.max(datasets.flat(), (d) => d.x) || 10;
    const maxY = d3.max(datasets.flat(), (d) => d.y) || 10;
    const minX = d3.min(datasets.flat(), (d) => d.x) || 0;
    const minY = d3.min(datasets.flat(), (d) => d.y) || 0;

    const bufferX = 3;
    const bufferY = 3;
    const xScale = d3
      .scaleLinear()
      .domain([minX - bufferX, maxX + bufferX])
      .range([0, width]);

    const yScale = d3
      .scaleLinear()
      .domain([minY - bufferY, maxY + bufferY])
      .range([height, 0]);

    const xAxis = d3.axisBottom(xScale);
    const yAxis = d3.axisLeft(yScale);

    g.append("g")
      .attr("transform", `translate(0,${height})`)
      .call(xAxis)
      .call((g) => g.selectAll(".domain, .tick line").attr("stroke", "#000"))
      .call((g) => g.selectAll(".tick text").attr("fill", "#000"))
      .append("text")
      .attr("class", "x-axis-label")
      .attr("fill", "#000")
      .attr("x", width / 2)
      .attr("y", MARGIN.bottom - 1)
      .attr("dy", "1em")
      .style("text-anchor", "middle");

    g.append("g")
      .call(yAxis)
      .call((g) => g.selectAll(".domain, .tick line").attr("stroke", "#000"))
      .call((g) => g.selectAll(".tick text").attr("fill", "#000"))
      .append("text")
      .attr("class", "y-axis-label")
      .attr("fill", "#000")
      .attr("transform", "rotate(-90)")
      .attr("x", -height / 2)
      .attr("y", -MARGIN.left + 10)
      .attr("dy", "1em")
      .style("text-anchor", "middle");

    datasets.forEach((data, i) => {
      // console.log(data);
      const color = colors[i % colors.length];

      g.selectAll(`.dot${i}`)
        .data(data)
        .enter()
        .append("circle")
        .attr("class", `dot${i}`)
        .attr("cx", (d) => xScale(d.x))
        .attr("cy", (d) => yScale(d.y))
        .attr("stroke", color)
        .attr("stroke-opacity", 1)
        .attr("stroke-width", 1)
        .attr("r", 2.5)
        .attr("fill", color)
        .attr("fill-opacity", 0.5);

      const legend = g
        .append("g")
        .attr("class", "legend")
        .attr("transform", `translate(${width - MARGIN.right},${i * 20})`);

      legend
        .append("rect")
        .attr("x", MARGIN.right)
        .attr("y", MARGIN.top -10)
        .attr("width", 18)
        .attr("height", 18)
        .style("fill", color);

      legend
        .append("text")
        .attr("x", MARGIN.right -10)
        .attr("y", MARGIN.top)
        .attr("dy", ".35em")
        .style("text-anchor", "end")
        .text(labels[i]);

      if (plotType === "rectangle") {
        [2, 3].forEach((sdMultiplier) => {
          const rectData = calculateConfidenceRectangle(
            mu[i],
            Sigma[i],
            sdMultiplier
          );
          const rectX = xScale(rectData.x);
          const rectY = yScale(rectData.y + rectData.height); 
          const rectWidth =
            xScale(rectData.x + rectData.width) - xScale(rectData.x);
          const rectHeight =
            yScale(rectData.y) - yScale(rectData.y + rectData.height);

          g.append("rect")
            .attr("x", rectX)
            .attr("y", rectY)
            .attr("width", rectWidth)
            .attr("height", rectHeight)
            .style("fill", "none")
            .style("stroke", color)
            .style("stroke-width", 1.5)
            .style("stroke-dasharray", sdMultiplier === 3 ? "4 2" : "none");
        });
      } else if (plotType === "ellipse") {
        [0.95, 0.99].forEach((confidence, j) => {
          const ellipseData = plotErrorEllipse(mu[i], Sigma[i], confidence);

          const cx = xScale(ellipseData.cx);
          const cy = yScale(ellipseData.cy);
          const rx = xScale(mu[i][0] + ellipseData.rx) - xScale(mu[i][0]);
          const ry = yScale(mu[i][1]) - yScale(mu[i][1] + ellipseData.ry);

          g.append("ellipse")
            .attr("cx", cx)
            .attr("cy", cy)
            .attr("rx", Math.abs(rx))
            .attr("ry", Math.abs(ry))
            .attr("transform", `rotate(${ellipseData.angle}, ${cx}, ${cy})`)
            .style("fill", "none")
            .style("stroke", color)
            .style("stroke-width", 1.5)
            .style("stroke-dasharray", j === 1 ? "4 2" : "none");
        });
      }
      const kdeX = kernelDensityEstimator(
        epanechnikovKernel(bandwidth),
        xScale.ticks(100)
      );
      const kdeY = kernelDensityEstimator(
        epanechnikovKernel(bandwidth),
        yScale.ticks(100)
      );
      const kdeDataX = kdeX(data.map((d) => d.x));
      const lineX = d3
        .line()
        .x((d) => xScale(d[0]))
        .y((d) => yScale(d[1]))
        .curve(d3.curveBasis);
      g.append("path")
        .datum(kdeDataX)
        .attr("class", "kde-x")
        .attr("d", lineX)
        .attr("fill", "none")
        .attr("stroke", color)
        .attr("stroke-width", 1);

      const kdeDataY = kdeY(data.map((d) => d.y));
      const lineY = d3
        .line()
        .x((d) => xScale(d[1]))
        .y((d) => yScale(d[0]))
        .curve(d3.curveBasis);

      g.append("path")
        .datum(kdeDataY)
        .attr("class", "kde-y")
        .attr("d", lineY)
        .attr("fill", "none")
        .attr("stroke", color)
        .attr("stroke-width", 1);
    });
  }, [width, height, datasets, labels, plotType, p, bandwidth]);

  return (
    <div>
      <svg ref={svgRef}></svg>
    </div>
  );
};

export default Scatterplot;

Current output:

enter image description here

Desired output:

enter image description here

2
  • 1
    Going to be very hard to give you a definitive answer given the small amount of non-reproducible code you show... But I'd approach the problem by placing the two KDE lines in their own g element and then translating each of those into the right or top margins respectively. Commented Jun 12, 2024 at 17:17
  • 1
    @Mark Your suggestion worked, I created 2 elements one for each plot. Thank you so much! Commented Jun 12, 2024 at 17:48

0

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.