from net_3tensor access *;

// Pen shared by every mode (index) label in all panels.
pen mode_label_pen = magenta + fontsize(8pt);

// One independent picture per panel; the panels are placed side by side in COMBINE.
picture[] P = new picture[];
for (int n = 0; n < 6; ++n)
  P.push(new picture);

// STEP 1 ==============================================================================
// Panel 0: the network built by three_tensor(attach_input=true), with tensors 0-4
// named, each mode labelled, and a dashed box spanning tensors 0 ($\alpha$) and 3 ($A$).

network N1 = three_tensor(attach_input=true);

N1.tensors[0].label = "$\alpha$";
N1.tensors[1].label = "$\beta$";
N1.tensors[2].label = "$\gamma$";
N1.tensors[3].label = "$A$";
N1.tensors[4].label = "$B$";

N1.draw(P[0]);

// Centres of the tensors that the annotations below are anchored to.
pair s1_alpha = N1.tensors[0].mid;
pair s1_beta  = N1.tensors[1].mid;
pair s1_gamma = N1.tensors[2].mid;
pair s1_A     = N1.tensors[3].mid;

// Mode labels just below alpha and beta.
label(P[0], "$i'$", s1_alpha + (-7.5, -15), mode_label_pen);
label(P[0], "$k$",  s1_alpha + ( 7.5, -15), mode_label_pen);
label(P[0], "$k'$", s1_beta  + (-8.2, -15), mode_label_pen);
label(P[0], "$j'$", s1_beta  + ( 8.5, -15), mode_label_pen);

// Mode labels just above gamma.
label(P[0], "$i$", s1_gamma + (-7.5, 15), mode_label_pen);
label(P[0], "$j$", s1_gamma + ( 7.5, 15), mode_label_pen);

// $\ell$ sits at the centroid of alpha, beta and gamma, nudged up and to the right.
pair s1_ell = (s1_alpha + s1_beta + s1_gamma) / 3 + (4, 10);
label(P[0], "$\ell$", s1_ell, mode_label_pen);

draw(P[0], box(s1_alpha + (-15, 15), s1_A + (15, -15)), p=blue+dashed);

// STEP 2 ==============================================================================
// Panel 1: Aa, beta and gamma joined at the mode_join point lp, with B below beta and
// a dashed box spanning tensors 1 (beta) and 3 (B) of this network.

network N2;

tensor Aa = N2.add_tensor((0, -17.5), 10 ,10);

//Aa.label="\begin{tabular}{c} a\\x\end{tabular}";

tensor beta = N2.add_tensor((50, 0), 10, 10, label="$\beta$");
tensor gamma = N2.add_tensor((25, 50), 10, 10, label="$\gamma$");
mode_join lp = N2.add_mode_join((25, 25));

// One edge from each of Aa, beta and gamma into the join point.
N2.join(Aa.mid + (0,10), lp.mid);
N2.join(beta.mid + (0,10), lp.mid);
N2.join(gamma.mid + (0,-10), lp.mid);

// Two paths hanging below beta (x offsets -5 then +5)...
for (int s = -1; s <= 1; s += 2)
  N2.add_path(smooth_vertical_path(beta.mid + (5*s, -10), beta.mid + (5*s, -25)));
// ...then two paths rising above gamma (x offsets -5 then +5).
for (int s = -1; s <= 1; s += 2)
  N2.add_path(smooth_vertical_path(gamma.mid + (5*s, 10), gamma.mid + (5*s, 25)));

tensor B = N2.add_tensor(beta.mid + (0, -35), 10, 10, pen=green, label="$B$");

N2.draw(P[1]);

pair s2_beta  = N2.tensors[1].mid;
pair s2_gamma = N2.tensors[2].mid;

// Mode labels below beta and above gamma.
label(P[1], "$k'$", s2_beta + (-8.2, -15), mode_label_pen);
label(P[1], "$j'$", s2_beta + (8.5, -15), mode_label_pen);
label(P[1], "$i$", s2_gamma + (-7.5, 15), mode_label_pen);
label(P[1], "$j$", s2_gamma + (7.5, 15), mode_label_pen);

// $\ell$ reuses the position of panel 0's label (centroid of N1's first three tensors).
pair s2_ell = (N1.tensors[0].mid + N1.tensors[1].mid + N1.tensors[2].mid) / 3 + (4,10);
label(P[1], "$\ell$", s2_ell, mode_label_pen);

draw(P[1], box(s2_beta + (-15,15), N2.tensors[3].mid + (15, -15)), p=blue+dashed);



// STEP 3 ==============================================================================
// Panel 2: Aa, Bb and gamma joined at the mode_join point lp, with Aa and Bb framed
// together by a dashed box.

network N3;

tensor Aa = N3.add_tensor((0, -17.5), 10 ,10);
tensor Bb = N3.add_tensor((50, -17.5), 10, 10);
tensor gamma = N3.add_tensor((25, 50), 10, 10, label="$\gamma$");
mode_join lp = N3.add_mode_join((25, 25));
N3.join(Aa.mid + (0,10), lp.mid);
N3.join(Bb.mid + (0,10), lp.mid);
N3.join(gamma.mid + (0,-10), lp.mid);

// Two paths rising above gamma.
N3.add_path(smooth_vertical_path(gamma.mid + (-5,10), gamma.mid + (-5, 25)));
N3.add_path(smooth_vertical_path(gamma.mid + (5,10), gamma.mid + (5, 25)));

N3.draw(P[2]);

// Labels for the paths above gamma, anchored to this panel's own gamma (not to a tensor
// of another network) so they stay on the drawn tensor if any layout changes.
label(P[2], "$i$", gamma.mid + (-7.5, 15), mode_label_pen);
label(P[2], "$j$", gamma.mid + (7.5, 15), mode_label_pen);
// Same position as panel 0's $\ell$ label (centroid of N1's first three tensors).
label(P[2], "$\ell$", (N1.tensors[0].mid + N1.tensors[1].mid + N1.tensors[2].mid) / 3 + (4,10), mode_label_pen);

draw(P[2], box(Aa.mid + (-15,15), Bb.mid + (15, -15)), p=blue+dashed);

// STEP 4 ==============================================================================
// Panel 3: tensor AaBb joined directly to gamma (no mode_join point), both framed by
// a dashed box.

network N4;

tensor AaBb = N4.add_tensor((25, -17.5), 10 ,10);
tensor gamma = N4.add_tensor((25, 50), 10, 10, label="$\gamma$");
N4.join(AaBb.mid + (0,10), gamma.mid + (0, -10));

// Two paths rising above gamma.
N4.add_path(smooth_vertical_path(gamma.mid + (-5,10), gamma.mid + (-5, 25)));
N4.add_path(smooth_vertical_path(gamma.mid + (5,10), gamma.mid + (5, 25)));

N4.draw(P[3]);

// Labels for the paths above gamma. They are anchored to the gamma drawn in THIS panel;
// anchoring to N1's gamma would misplace them whenever three_tensor() puts its gamma
// anywhere other than (25, 50).
label(P[3], "$i$", gamma.mid + (-7.5, 15), mode_label_pen);
label(P[3], "$j$", gamma.mid + (7.5, 15), mode_label_pen);
// Same position as panel 0's $\ell$ label (centroid of N1's first three tensors).
label(P[3], "$\ell$", (N1.tensors[0].mid + N1.tensors[1].mid + N1.tensors[2].mid) / 3 + (4,10), mode_label_pen);

draw(P[3], box(AaBb.mid + (-15,-15), gamma.mid + (15, 15)), p=blue+dashed);

// STEP 5 ==============================================================================
// Panel 4: a single tensor labelled $A \cdot B$ with two paths (i and j) above it.

network N5;

// y = (50-17.5)/2 is the midpoint of the rows y = 50 and y = -17.5 used in step 4.
tensor C = N5.add_tensor((25, (50-17.5)/2), 15, 10, label="$A \cdot B$");

// Two paths rising above C (x offsets -5 then +5).
for (int s = -1; s <= 1; s += 2)
  N5.add_path(smooth_vertical_path(C.mid + (5*s, 10), C.mid + (5*s, 25)));

N5.draw(P[4]);

pair s5_top = C.mid + (0, 20);
label(P[4], "$i$", s5_top + (-7.5, 0), mode_label_pen);
label(P[4], "$j$", s5_top + (7.5, 0), mode_label_pen);


// STEP 6 ==============================================================================
// Panel 5: a fresh three_tensor(attach_input=true) network, drawn into P[5] with the
// draw_execution option enabled (what that option renders is defined in net_3tensor,
// not visible here).

network N6 = three_tensor(attach_input=true);
N6.draw(P[5], draw_execution=true);

// COMBINE ==============================================================================
// Place the six panels side by side, draw a connector to the left of panels 1-5,
// and draw an outer blue frame.

real GAP = 90;

// Two parallel horizontal blue strokes, 15 units long and 4 apart, ending at x.
void draw_double_shaft(real x, real y) {
  draw((x-15, y-2)--(x, y-2), p=blue);
  draw((x-15, y+2)--(x, y+2), p=blue);
}

// Hand-tuned horizontal corrections to the regular GAP spacing, indexed by panel.
real[] panel_shift = {0, 0, 10, 0, -30, -30};
// Hand-tuned corrections to the connector in front of panels 1-4, indexed by panel.
real[] connector_shift = {0, 0, -5, 10, 10, 0};

for (int k = 0; k < 6; ++k) {
  real x0 = GAP*k + panel_shift[k];
  add(P[k], (x0, 0));

  real cy = 20;
  if (k >= 1 && k <= 4) {
    // Double shaft followed by an open arrowhead pointing right.
    real cx = x0 - 10 + connector_shift[k];
    draw_double_shaft(cx, cy);
    draw((cx-6, cy-6)--(cx+3, cy)--(cx-6, cy+6), p=blue);
  } else if (k == 5) {
    // Double shaft with a small closed triangle above it instead of an arrowhead.
    real cx = x0 - 25;
    draw_double_shaft(cx, cy);
    draw((cx-10, cy+6)--(cx-5, cy+6)--(cx-7.5, cy+10)--cycle, p=blue);
  }
}

draw(box((-20, -55), (GAP*5 - 75, 80)), p=blue);
