from pathlib import Path
from nanotools.utils import Quantity, ureg
import io
import json
import numpy as np
import toml
def json_decoder(obj, ignore: list = None):
    """Rebuild special objects from dicts produced by ``SpecialCaseEncoder``.

    Designed to be used as the ``object_hook`` of ``json.load``/``json.loads``:
    it receives every decoded JSON object (a ``dict``) and returns either a
    reconstructed object or the dict unchanged.

    Args:
        obj: A decoded JSON object.
        ignore: ``__class__`` tags (``"Quantity"``, ``"ndarray"``) that must NOT
            be converted; dicts with those tags are returned unchanged. Because
            ``object_hook`` is called with ``obj`` only, bind this argument with
            e.g. ``functools.partial`` when needed.

    Returns:
        A unit-registry quantity for ``{"__class__": "Quantity", ...}``, a
        ``numpy.ndarray`` for ``{"__class__": "ndarray", ...}``, and ``obj``
        itself for anything else.
    """
    if ignore is None:
        ignore = []
    tag = obj.get("__class__")
    if tag == "Quantity" and "Quantity" not in ignore:
        # Construct the quantity directly instead of ``magnitude * ureg(units)``:
        # pint refuses to multiply a quantity carrying an offset unit (e.g. degC)
        # by a number, so the multiplication form could not round-trip those.
        return ureg.Quantity(obj["magnitude"], obj["units"])
    if tag == "ndarray" and "ndarray" not in ignore:
        # The encoder flattens data in Fortran (column-major) order; undo that.
        shape = tuple(obj["shape"])
        array = np.array(obj["array"]).reshape(shape, order="F")
        # Dicts without an "iscomplex" key are treated as real-valued.
        if obj.get("iscomplex", False):
            array = array + 1.0j * np.array(obj["imag"]).reshape(shape, order="F")
        return array
    return obj
def json_read(filename):
    """Load the contents of a JSON or TOML file.

    Args:
        filename: A path (``str`` or path-like) or an open text file object
            (``io.TextIOWrapper``). For a file object only its ``name`` is used;
            the file is reopened by path.

    Returns:
        The parsed content. ``.toml`` files are parsed with ``toml.load``.
        ``.json`` files, and (after a printed warning) files with any other
        extension, are parsed with ``json.load`` using ``json_decoder`` as
        ``object_hook`` so encoded ``Quantity``/``ndarray`` objects are rebuilt.
        The extension match is case-insensitive.

    Raises:
        FileNotFoundError: If the path does not exist.
    """
    source = filename.name if isinstance(filename, io.TextIOWrapper) else filename
    path = Path(source)
    if not path.exists():
        raise FileNotFoundError(f"No such file: '{path}'")
    suffix = path.suffix.lower()
    # JSON (RFC 8259) and TOML are both specified as UTF-8; do not depend on
    # the platform's locale-dependent default encoding.
    with path.open(encoding="utf-8") as fid:
        if suffix == ".toml":
            return toml.load(fid)
        if suffix != ".json":
            print(
                f"WARNING: Unknown extension {path.suffix}, attempting to load as json file."
            )
        return json.load(fid, object_hook=json_decoder)
def json_write(filename, adict):
    """Serialize ``adict`` to a file as indented JSON with sorted keys.

    Args:
        filename: A path (``str`` or path-like) or an open text file object
            (``io.TextIOWrapper``). For a file object only its ``name`` is used;
            the file is reopened by path in ``"w"`` mode, truncating it.
        adict: The object to serialize. ``SpecialCaseEncoder`` handles the
            non-standard types (quantities, arrays, paths).
    """
    target = Path(
        filename.name if isinstance(filename, io.TextIOWrapper) else filename
    )
    with target.open(mode="w") as handle:
        json.dump(
            adict,
            handle,
            cls=SpecialCaseEncoder,
            sort_keys=True,
            indent=2,
        )
class SpecialCaseEncoder(json.JSONEncoder):
    """JSON encoder for a few types the standard encoder cannot handle.

    Encodings produced by ``default``:

    * ``numpy.ndarray`` -> ``{"__class__": "ndarray", "iscomplex", "ndim",
      "shape", "array"[, "imag"]}``. Data is flattened in Fortran
      (column-major) order. For complex arrays the real part goes to
      ``"array"`` and the imaginary part to ``"imag"``.
    * NumPy scalars (``numpy.generic``, e.g. ``np.int64``) -> the equivalent
      Python scalar via ``.item()``.
    * ``pathlib.Path`` -> the resolved absolute path as a POSIX string.
    * ``Quantity`` -> ``{"__class__": "Quantity", "magnitude", "units"}``.

    Anything else that ``json.JSONEncoder`` rejects triggers a printed warning
    and is encoded as ``null``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        ``json`` calls this only for objects it cannot encode natively.
        """
        if isinstance(obj, np.ndarray):
            encoded = dict(
                __class__="ndarray",
                iscomplex=np.iscomplexobj(obj),
                ndim=obj.ndim,
                shape=obj.shape,
                # Fortran order pairs with the order="F" reshape used on decode.
                array=np.real(obj).flatten(order="F").tolist(),
            )
            if encoded["iscomplex"]:
                encoded["imag"] = np.imag(obj).flatten(order="F").tolist()
            return encoded
        if isinstance(obj, np.generic):
            # np.int64 / np.float32 / np.bool_ etc. are not Python int/float/bool
            # subclasses, so json cannot encode them. Without this branch they
            # reached the fallback below and were silently written as null.
            return obj.item()
        if isinstance(obj, Path):
            return obj.resolve().as_posix()
        if isinstance(obj, Quantity):
            # If the magnitude is an ndarray, json calls default() on it again
            # and it is encoded by the ndarray branch above.
            return dict(
                __class__="Quantity",
                magnitude=obj.m,
                units=str(obj.u),
            )
        try:
            return json.JSONEncoder.default(self, obj)
        except TypeError:
            # Deliberate best-effort: report the unsupported type and encode it
            # as null instead of aborting the whole dump.
            print(
                f"Warning: TypeError in JSONEncoder ({type(obj).__name__}). Pass."
            )
            return None