Cod sursa (job #2778979)

Utilizator FunnyStocky Mihnea Andreescu FunnyStocky Data 2 octombrie 2021 14:20:23
Problema Flux maxim de cost minim Scor 100
Compilator cpp-64 Status done
Runda Arhiva educationala Marime 2.54 kb
#include <cstdio>
#include <vector>
#include <queue>
#include <algorithm>
#include <iostream>
#include <cassert>

using namespace std;

// 64-bit type used for the accumulated total cost.
using ll = long long;
// Array bound; vertices are indexed from 1, so this leaves slack above the
// largest vertex id (presumably n <= 350 per the problem limits — confirm).
constexpr int N = 350 + 7;
// Sentinel "infinity" for distances and bottleneck capacities.
constexpr int INF = static_cast<int>(1e9);
// Vertex count, edge count, source and sink (read in main()).
int n, m;
int s, d;
// Residual capacity matrix: cap[u][v] is the capacity still usable on u->v.
int cap[N][N];
// Cost matrix; add() stores cost[v][u] == -cost[u][v] for each edge.
int cost[N][N];
// Adjacency lists containing both directions of every added edge.
vector<int> g[N];
// Shortest-path state per vertex: distance (also used as potential),
// parent in the shortest-path tree, bottleneck capacity along that path.
int dmin[N];
int par[N];
int mn[N];
// "Currently in queue" flags for belf().
bool inq[N];

// Adds a directed edge a->b with capacity cp and per-unit cost cst, plus its
// residual reverse arc b->a (capacity left untouched, cost -cst).
// Note: entries are stored in N x N matrices and assigned, not accumulated,
// so a second edge between the same pair of vertices (either direction)
// overwrites cap/cost of the first one.
void add(int a, int b, int cp, int cst) {
  cost[a][b] = cst;
  cost[b][a] = -cst;
  cap[a][b] = cp;
  // Both endpoints see each other so residual arcs can be traversed later.
  g[a].emplace_back(b);
  g[b].emplace_back(a);
}

void belf() {
  for (int i = 1; i <= n; i++) {
    dmin[i] = INF;
    par[i] = -1;
  }
  dmin[s] = 0;
  mn[s] = INF;
  queue<int> q;
  q.push(s);
  inq[s] = 1;
  while (!q.empty()) {
    int a = q.front();
    inq[a] = 0;
    q.pop();
    for (auto &b : g[a]) {
      if (cap[a][b] > 0 && dmin[a] + cost[a][b] < dmin[b]) {
        dmin[b] = dmin[a] + cost[a][b];
        par[b] = a;
        mn[b] = min(cap[a][b], mn[a]);
        if (!inq[b]) {
          q.push(b);
          inq[b] = 1;
        }
      }
    }
  }
}

// dd: Dijkstra keys (distances under reduced costs) for dij().
int dd[N];
// dnew: real-cost distances accumulated by dij(), later copied into dmin.
int dnew[N];

void dij() {
  for (int i = 1; i <= n; i++) {
    dd[i] = INF;
    par[i] = -1;
    dnew[i] = dmin[i];
  }
  mn[s] = INF;
  dnew[s] = 0;
  priority_queue<pair<int, int>> q;
  q.push({0, s});
  dd[s] = 0;
  while (!q.empty()) {
    int a = q.top().second;
    if (dd[a] != -q.top().first) {
      q.pop();
      continue;
    }
    q.pop();
    for (auto &b : g[a]) {
      if (cap[a][b] > 0 && dd[a] + cost[a][b] + dmin[a] - dmin[b] < dd[b]) {
        dnew[b] = dnew[a] + cost[a][b];
        dd[b] = dd[a] + cost[a][b] + dmin[a] - dmin[b];
        par[b] = a;
        mn[b] = min(cap[a][b], mn[a]);
        q.push({-dd[b], b});
      }
    }
  }
  for (int i = 1; i <= n; i++) {
    dmin[i] = dnew[i];
  }
}

int main() {
  freopen ("fmcm.in", "r", stdin);
  freopen ("fmcm.out", "w", stdout);

  scanf("%d %d %d %d", &n, &m, &s, &d);
  for (int i = 1; i <= m; i++) {
    int a, b, c, d;
    scanf("%d %d %d %d", &a, &b, &c, &d);
    add(a, b, c, d);
  }
  belf();
  ll sol = 0;
  while (1) {
    dij();
    if (par[d] == -1) {
      break;
    }
    sol += (ll) mn[d] * dmin[d];
    int val = mn[n];
    int node = d;
    vector<int> path;
    while (1) {
      path.push_back(node);
      if (node == s) {
        break;
      }
      node = par[node];
    }
    reverse(path.begin(), path.end());
    for (int i = 0; i + 1 < (int) path.size(); i++) {
      cap[path[i]][path[i + 1]] -= mn[d];
      cap[path[i + 1]][path[i]] += mn[d];
    }
  }
  printf("%lld\n", sol);
}