# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.14.5
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
#   metadata:
#     section: mf6
# ---

# # Working with the Multi-node Well (MNW2) Package

import os

# +
import sys
from tempfile import TemporaryDirectory

import numpy as np

# pandas is required, not optional: every table in this notebook is built or
# displayed with it, so fail here with a clear ImportError if it is missing.
import pandas as pd

# run installed version of flopy or add local path
try:
    import flopy
except ImportError:
    # Not installed: make the package two directories up importable instead.
    fpth = os.path.abspath(os.path.join("..", ".."))
    sys.path.append(fpth)
    import flopy

print(sys.version)
print(f"numpy version: {np.__version__}")
print(f"pandas version: {pd.__version__}")
print(f"flopy version: {flopy.__version__}")
# -

# ### Make an MNW2 package from scratch

# +
# Scratch workspace for every file written below; it is removed in the last cell.
temp_dir = TemporaryDirectory()
model_ws = temp_dir.name

# 3 layers x 5 rows x 5 columns, 3 stress periods.
m = flopy.modflow.Modflow("mnw2example", model_ws=model_ws)
dis = flopy.modflow.ModflowDis(
    model=m, nlay=3, nrow=5, ncol=5, nper=3, top=10, botm=0
)
# -

# ### MNW2 information by node
# (this could be prepared externally from well records and read in from a CSV or Excel file)
# * this table has two multi-node wells, the first (well1) consisting of two nodes that are manually specified
# (where the variable **rw** is specified by node)
# * note that some variables that are constant for the whole well are also included (losstype, zpump, etc.)

# One row per well node. Well-level settings (losstype, pumploc, qlimit,
# ppflag, pumpcap, zpump) are repeated on every row belonging to the same well.
node_columns = [
    "i",
    "j",
    "ztop",
    "zbotm",
    "wellid",
    "losstype",
    "pumploc",
    "qlimit",
    "ppflag",
    "pumpcap",
    "rw",
    "rskin",
    "kskin",
    "zpump",
]
node_rows = [
    [1, 1, 9.5, 7.1, "well1", "skin", -1, 0, 0, 0, 1.0, 2.0, 5.0, 6.2],
    [1, 1, 7.1, 5.1, "well1", "skin", -1, 0, 0, 0, 0.5, 2.0, 5.0, 6.2],
    [3, 3, 9.1, 3.7, "well2", "skin", -1, 0, 0, 0, 1.0, 2.0, 5.0, 4.1],
]
node_data = pd.DataFrame(node_rows, columns=node_columns)
node_data

# #### convert the DataFrame to a recarray for compatibility with flopy

# to_records carries the DataFrame index along as an extra "index" field;
# ModflowMnw2 ignores columns it does not recognize.
node_data = node_data.to_records(index=True)
node_data

# ### Stress period information
# (could also be developed externally)

stress_period_data = pd.DataFrame(
    [
        [0, "well1", 0],
        [1, "well1", 100.0],
        [0, "well2", 0],
        [1, "well2", 1000.0],
    ],
    columns=["per", "wellid", "qdes"],
)
stress_period_data

# ModflowMnw2 expects a dict of recarrays keyed by stress period. Iterate over
# the groups instead of hard-coding the period numbers, so every period present
# in the table is included; int() guarantees plain Python int keys.
pers = stress_period_data.groupby("per")
stress_period_data = {int(per): group.to_records() for per, group in pers}
stress_period_data

# ### Make ``ModflowMnw2`` package object
# * note that extraneous columns in node_data and stress_period_data are ignored
# * if itmp is positive, it must equal the number of active wells being specified in ``stress_period_data``, otherwise the package class will raise an error.

# Both wells are active in the first two periods (itmp=2); itmp=-1 reuses the
# second period's pumping for the last stress period.
mnw2 = flopy.modflow.ModflowMnw2(
    model=m,
    mnwmax=2,
    itmp=[2, 2, -1],
    node_data=node_data,
    stress_period_data=stress_period_data,
)

# "nodtot" is computed automatically
mnw2.nodtot

pd.DataFrame(data=mnw2.node_data)

pd.DataFrame(data=mnw2.stress_period_data[0])

pd.DataFrame(data=mnw2.stress_period_data[1])

# Package built without any node or stress-period tables; it is used below to
# generate an empty node_data table. itmp=-1 reuses the second period's
# pumping for the last stress period.
tmp = flopy.modflow.ModflowMnw2(model=m, itmp=[1, 1, -1])

# ### empty ``node_data`` and ``stress_period_data`` tables can also be generated by the package class, and then filled

# Empty node_data recarray (presumably 3 blank node rows) to be filled in by hand.
node_data = tmp.get_empty_node_data(3)
node_data

# ### Mnw objects
# at the base of the flopy mnw2 module is the **Mnw** object class, which describes a single multi-node well.
# A list or dict of **Mnw** objects can be used to build a package (using the example above):
# ```
# flopy.modflow.ModflowMnw2(model=m, mnwmax=2,
#                  mnw=<dict or list of Mnw objects>,
#                  itmp=[1, 1, -1], # reuse second per pumping for last stress period
#                  )
# ```
# or if node_data and stress_period_data are supplied, the **Mnw** objects are created on initialization of the ModflowMnw2 class instance, and assigned to the ```.mnw``` attribute, as items in a dictionary keyed by ```wellid```.

# Dict of Mnw objects keyed by wellid, created when the package was built from node_data/stress_period_data.
mnw2.mnw

# Per-node variables (such as rw) are stored as lists on each Mnw object.
mnw2.mnw["well1"].rw

# Note that Mnw object attributes for variables that vary by node are lists (e.g. ``rw`` above)
#
# #### Each Mnw object has its own ``node_data`` and ``stress_period_data``

# Node table of the well1 Mnw object only.
pd.DataFrame(data=mnw2.mnw["well1"].node_data)

# #### Instead of a dict keyed by stress period, Mnw.stress_period_data is a recarray with pumping data listed by stress period for that well
# * note that data for period 2, where ``itmp`` < 0, is shown (it was copied from period 1 during construction of the **Mnw** object)

# Per-period pumping records for well2, including the last (reused) period.
pd.DataFrame(data=mnw2.mnw["well2"].stress_period_data)

# ### Build the same package using only the ``Mnw`` objects

# Rebuild the same package from the Mnw objects alone (no tables passed in).
mnw2fromobj = flopy.modflow.ModflowMnw2(
    model=m,
    mnwmax=2,
    mnw=mnw2.mnw,
    itmp=[2, 2, -1],  # -1: reuse the second period's pumping in the last period
)

pd.DataFrame(data=mnw2fromobj.node_data)

pd.DataFrame(data=mnw2fromobj.stress_period_data[0])

pd.DataFrame(data=mnw2fromobj.stress_period_data[1])

# ### By default, the ``node_data`` and ``stress_period_data`` tables attached to the ``ModflowMnw2`` package class are definitive
# * on writing of the package output (``mnw2.write_file()``), the **Mnw** objects are regenerated from the tables. This setting is controlled by the default argument ``use_tables=True``. To write the package file using the **Mnw** objects (ignoring the tables), use ``mnw2.write_file(use_tables=False)``.

# Empty stress-period recarray (presumably sized for itmp=2 active wells) to be filled in by hand.
per1 = flopy.modflow.ModflowMnw2.get_empty_stress_period_data(itmp=2)
per1

# ### Write an MNW2 package file and inspect the results

# Write the package file into the scratch workspace.
mnw2_file = os.path.join(model_ws, "test.mnw2")
mnw2.write_file(mnw2_file)

# Echo the written file line by line; the with-block closes the handle.
with open(mnw2_file) as f:
    for line in f:
        print(line.rstrip("\n"))

# ### Load some example MNW2 packages

# Example MNW2 inputs, located relative to the current working directory.
path = os.path.join("..", "..", "examples", "data", "mnw2_examples")
fig28 = "MNW2-Fig28"
m = flopy.modflow.Modflow(fig28, model_ws=model_ws)
dis = flopy.modflow.ModflowDis.load(os.path.join(path, f"{fig28}.dis"), m)

m.get_package_list()

# Load the MNW2 package into the model that already holds the DIS package.
mnw2pth = os.path.join(path, f"{fig28}.mnw2")
mnw2 = flopy.modflow.ModflowMnw2.load(mnw2pth, m)

pd.DataFrame(data=mnw2.node_data)

pd.DataFrame(data=mnw2.stress_period_data[0])

mnw2.mnw

pd.DataFrame(data=mnw2.mnw["well-a"].stress_period_data)

# Load a second example MNW2 file (no DIS package is loaded for this model).
# Build the file path with os.path.join, like the other example paths, rather
# than by concatenating a "/" separator.
path = os.path.join("..", "..", "examples", "data", "mnw2_examples")
m = flopy.modflow.Modflow("br", model_ws=model_ws)
mnw2pth = os.path.join(path, "BadRiver_cal.mnw2")
mnw2 = flopy.modflow.ModflowMnw2.load(mnw2pth, m)

# Show only columns whose values do not sum to zero; non-numeric columns
# compare unequal to 0 and are therefore kept.
df = pd.DataFrame(mnw2.node_data)
df.loc[:, df.sum(axis=0) != 0]

# + pycharm={"name": "#%%\n"}
# Remove the scratch workspace. Cleanup can raise PermissionError on Windows,
# so filesystem (OSError) failures are ignored as best-effort; any other
# exception is a genuine error and is allowed to surface.
try:
    temp_dir.cleanup()
except OSError:
    pass
