from networks access *;

// Builds the tensor-network diagram of a 5-point FFT, optionally partway
// through its evaluation.
//
// attach_input   - add an input tensor at the right-hand end and an execution
//                  chain linking it to every remaining butterfly/twiddle.
// reverse_output - add crossing output legs so that output row r is routed
//                  to row 4-r, 20 units left of the outputs.
// label_nodes    - label butterflies "$H_2$" and twiddles "$R^{5,k}$".
// step           - number of steps already taken; each odd step drops one
//                  butterfly, each even step drops one twiddle, and odd steps
//                  also pull the input column 40 units to the left.
// line_spacing   - vertical distance between the five mode rows.
network fft5(bool attach_input=false, bool reverse_output=false, bool label_nodes=false, int step=0, int line_spacing=15) {
  network N;

  int nbutterfly = 5 - quotient(step + 1, 2);
  int ntwiddle = 4 - quotient(step, 2);

  // Butterflies step diagonally down-right, one per row.
  tensor butterfly[] = new tensor[];
  for (int k = 0; k < nbutterfly; ++k) {
    butterfly.push(N.add_tensor((120*k + 40, -line_spacing*k), 10, 10,
                                label=(label_nodes ? "$H_2$" : "")));
  }

  // Twiddles sit above the rows, between consecutive butterflies.
  tensor twiddle[] = new tensor[];
  for (int k = 0; k < ntwiddle; ++k) {
    string lbl = "";
    if (label_nodes) {
      lbl = "$R^{5," + string(3 - k) + "}$";
    }
    twiddle.push(N.add_tensor((120*k + 100, 20), 40, 10, label=lbl));
  }

  int input_x = 120*ntwiddle + 80 - 40*(step % 2);
  tensor input_node;
  if (attach_input) {
    // Highlight the input only before any step has been taken.
    pen p = green;
    if (step > 0) {
      p = defaultpen;
    }
    input_node = N.add_tensor((input_x, -2*line_spacing), 10, 2*line_spacing + 10, pen=p);
  }

  // Left (output) and right (input) end points of the five rows.
  pair output[] = new pair[];
  pair input[] = new pair[];
  for (int row = 0; row < 5; ++row) {
    output.push((10, -line_spacing*row));
    input.push((input_x - 10, -line_spacing*row));
  }

  if (reverse_output) {
    for (int row = 0; row < 5; ++row) {
      // Route row `row` onto row 4-row, then extend it 10 units further left.
      pair mirrored = (output[row].x - 20, -4*line_spacing - output[row].y);
      N.join(output[row], mirrored);
      N.join(mirrored, mirrored + (-10, 0));
    }
  }

  // frontier[row] is the rightmost point wired so far on each row.
  pair frontier[] = new pair[];
  for (int row = 0; row < 5; ++row) {
    frontier.push(output[row]);
  }
  for (int k = 0; k < ntwiddle; ++k) {
    int x0 = 120*k + 70;
    for (int row = 0; row < 5; ++row) {
      int x = x0 + 15*row;
      pair previous = frontier[row];
      frontier[row] = N.add_mode_join((x, -line_spacing*row)).mid;
      N.join(previous, frontier[row]);
      // Vertical leg from the mode join up to y = line_spacing (towards twiddle k).
      N.join(frontier[row], (x, line_spacing));
    }
  }
  for (int row = 0; row < 5; ++row) {
    N.join(frontier[row], input[row]);
  }

  if (!attach_input) {
    return N;
  }

  // Execution order: butterfly 0, twiddle 0, butterfly 1, twiddle 1, ...,
  // then the trailing butterfly when there is one more butterfly than twiddle.
  tensor schedule[] = new tensor[];
  for (int k = 0; k < ntwiddle; ++k) {
    schedule.push(butterfly[k]);
    schedule.push(twiddle[k]);
  }
  if (nbutterfly > ntwiddle) {
    schedule.push(butterfly[nbutterfly - 1]);
  }

  // Chain execution nodes from the input leftwards through the schedule.
  execution_node cursor = input_node.exec;
  for (int k = schedule.length - 1; k >= 0; --k) {
    execution_node joint = N.add_execution_node((schedule[k].mid.x + 10, -6*line_spacing));
    N.exec_join(cursor, joint);
    N.exec_join(schedule[k].exec, joint);
    cursor = joint;
  }
  return N;
}


// Builds a convolution diagram out of three fft5 networks (all with reversed
// outputs):
//   - sub-network 0, unshifted;
//   - sub-networks 1 and 2, shifted right by 120*5+10 and up/down by five rows.
// A column of mode joins at sub-network 0's input legs is wired to the
// output leg ends of sub-networks 1 and 2. A final twiddle tensor is added at
// x = -70 to the left of sub-network 0.
//
// attach_input - forwarded to every fft5. When set, the last tensor of
//                sub-network 0 is removed and extra execution nodes are linked.
// line_spacing - vertical distance between mode rows. It is forwarded to the
//                sub-networks so their rows line up with the joins drawn here.
network convolution(bool attach_input=false, int line_spacing=15) {
  network N;
  // line_spacing must be forwarded: the joins and shifts below are computed
  // from it, and fft5 would otherwise fall back to its own default spacing.
  N.add_sub_network(fft5(reverse_output=true,
                         attach_input=attach_input,
                         line_spacing=line_spacing));
  if (attach_input) {
    // The last tensor fft5 adds when attach_input is set is its input node.
    // Presumably sub_networks[0].tensors keeps insertion order, so this drops
    // that input node — TODO confirm against the networks module.
    N.sub_networks[0].tensors.pop();
  }
  N.add_sub_network(fft5(reverse_output=true,
                         attach_input=attach_input,
                         line_spacing=line_spacing), shift=(120*5+10, 5*line_spacing));
  N.add_sub_network(fft5(reverse_output=true,
                         attach_input=attach_input,
                         line_spacing=line_spacing), shift=(120*5+10, -5*line_spacing));

  // x of sub-network 0's input leg ends (fft5 at step 0: input_x - 10).
  int input_x = 120*4 + 70;
  for (int i = 0; i < 5; ++i) {
    N.add_mode_join((input_x, -i*line_spacing));
    // 120*5-10 is where the shifted, reversed output legs of sub-networks 1
    // and 2 end (10 - 20 - 10 + 120*5+10).
    N.join((input_x, -i*line_spacing), (120*5-10, 5*line_spacing-i*line_spacing));
    N.join((input_x, -i*line_spacing), (120*5-10, -5*line_spacing-i*line_spacing));
  }

  tensor last_twiddle = N.add_tensor((-70, 20), 40, 10);
  for (int i = 0; i < 5; ++i) {
    int x = -70-30+15*i;
    N.add_mode_join((x, -i*line_spacing));
    N.join((x, -i*line_spacing), (x, 20));
    N.join((-10, -i*line_spacing), (-120, -i*line_spacing));
  }
  if (attach_input) {
    // Last execution node of sub-network 0. In fft5 the last execution node
    // added is the end of its execution chain (assumes exec keeps insertion
    // order — TODO confirm).
    execution_node root = N.sub_networks[0].exec[N.sub_networks[0].exec.length-1];
    execution_node subroot1 = N.add_execution_node(root.mid + N.sub_network_shifts[1]);
    // NOTE(review): hard-coded index. At step 0, fft5 adds 5 butterflies and
    // 4 twiddles before its input node, so index 9 presumably selects the
    // input node's execution node. This depends on how the networks module
    // numbers exec nodes — TODO confirm, and derive it instead of hard-coding.
    execution_node top_input = N.sub_networks[0].exec[9];
    N.exec_join(subroot1, top_input);
    execution_node subroot2 = N.add_execution_node(root.mid + N.sub_network_shifts[2]);
    N.exec_join(subroot2, top_input);

    execution_node newroot = N.add_execution_node((last_twiddle.exec.mid.x+10, root.mid.y));
    N.exec_join(root, newroot);
    N.exec_join(last_twiddle.exec, newroot);
  }
  return N;
}
