import sys
import click
import csv
from pathlib import Path
from substrateinterface import SubstrateInterface

def extract_info_from_csv(file_path):
    """Read account rows from a CSV file and prepare them for a balance check.

    Address columns are detected by header name: a header qualifies if it
    contains 'account', 'address', 'custody' or 'reserve' (case-insensitive)
    and does not contain 'Coin' or 'coin'. The column immediately to the right
    of an address column is taken as its balance column. In the output that
    balance column is renamed to 'starting_<name>' and a new 'current_<name>'
    column (initially None) is inserted right after it, to be filled in later
    from the chain. A column already paired as a balance column is never
    re-examined as an address column.

    Rows whose cells are all empty are skipped.

    Args:
        file_path: Path to the input CSV file; its first row is the header.

    Returns:
        A tuple ``(rows, balance_pairs)``:
          - rows: list of dicts keyed by the output field names, in output
            column order; every 'current_*' value is None.
          - balance_pairs: list of ``(address_column, current_balance_column)``
            tuples naming where to read each address and where to store its
            current balance.
        Both lists are empty when the file has no header row.

    Raises:
        ValueError: if an address column is the last column, so there is no
            balance column to pair it with.
    """
    address_keywords = ('account', 'address', 'custody', 'reserve')

    # newline='' is required by the csv module so quoted fields containing
    # line breaks are parsed correctly.
    with open(file_path, 'r', newline='') as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        if not fieldnames:
            # Empty file: no header, hence nothing to check.
            return [], []

        balance_pairs = []
        new_fieldnames = list(fieldnames)
        # Output field -> input column its value is copied from. Output fields
        # missing from this map (the inserted 'current_*' ones) start as None.
        source_of = {name: name for name in fieldnames}
        consumed_balance_cols = set()

        for idx, header in enumerate(fieldnames):
            if header in consumed_balance_cols:
                # Already used as the balance column of the address column to
                # its left (e.g. "Reserve Balance" also matches 'reserve').
                continue
            if 'Coin' in header or 'coin' in header:
                continue
            if not any(keyword in header.lower() for keyword in address_keywords):
                continue
            if idx + 1 >= len(fieldnames):
                raise ValueError(
                    f"Address column {header!r} has no balance column to its right"
                )

            balance_col = fieldnames[idx + 1]
            consumed_balance_cols.add(balance_col)
            starting_col = 'starting_' + balance_col
            cur_balance_col = 'current_' + balance_col
            balance_pairs.append((header, cur_balance_col))

            # Replace the balance column with "starting_<col>, current_<col>".
            pos = new_fieldnames.index(balance_col)
            new_fieldnames[pos:pos + 1] = [starting_col, cur_balance_col]
            del source_of[balance_col]
            source_of[starting_col] = balance_col

        new_rows = []
        for row in reader:
            if not row or all(not cell for cell in row.values()):
                continue
            new_rows.append({
                field: row[source_of[field]] if field in source_of else None
                for field in new_fieldnames
            })

    return new_rows, balance_pairs

def get_chain_balances(substrate: SubstrateInterface, acct_data, balance_pairs, decimals=9):
    """Query current balances for addresses from the chain.

    For every row in ``acct_data`` and every ``(addr_col, balance_col)`` pair,
    the address in ``row[addr_col]`` is looked up via the ``System.Account``
    storage item. Its ``data.free`` value, divided by ``10**decimals``, is
    stored in ``row[balance_col]``. Rows are updated in place.

    Lookups are best-effort. An empty address is reported and skipped. Any
    exception raised while querying or converting one address is reported on
    stderr without aborting the run, and that target cell is left unchanged.

    Args:
        substrate: Connected SubstrateInterface used for the storage queries.
        acct_data: List of row dicts, e.g. as returned by extract_info_from_csv.
        balance_pairs: ``(address column, current-balance column)`` tuples.
        decimals: Decimal places of the chain's raw balance unit (default 9).

    Returns:
        ``acct_data`` itself, mutated in place.
    """
    scale = 10 ** decimals

    # lineno counts data rows (1-based) as given in acct_data.
    for lineno, acct in enumerate(acct_data, start=1):
        for addr_col, balance_col in balance_pairs:
            try:
                acct_addr = acct[addr_col]
                if not acct_addr:
                    click.echo(f"Skipping line # {lineno} due to missing address", err=False)
                    continue

                result = substrate.query(
                    module='System',
                    storage_function='Account',
                    params=[acct_addr],
                )
                if result:
                    # Convert the raw free balance to display units.
                    balance = result.value['data']['free']
                    acct[balance_col] = balance / scale
            except Exception as e:
                # Deliberately best-effort per address: report and move on.
                click.echo(f"Error querying balance for {acct}: {str(e)}", err=True)

    return acct_data

@click.command()
@click.argument('csv_file', type=click.Path(exists=True))
@click.option('--rpc-url', default='ws://127.0.0.1:9944', help='Substrate RPC endpoint')
@click.option('--output', '-o', type=click.Path(), help='Output file path')
def main(csv_file, rpc_url, output):
    """Check wallet balances against the chain."""
    # NOTE: click shows the docstring above as the command's --help text.

    # Parse the CSV before connecting so empty or unusable input is reported
    # without needing a reachable node.
    acct_data, balance_pairs = extract_info_from_csv(csv_file)
    if not acct_data:
        click.echo("No addresses found in CSV file")
        return
    if not balance_pairs:
        # Nothing to look up; writing an "unchecked" copy would be misleading.
        click.echo("No address columns found in CSV header")
        return

    # Connect to the substrate node. ss58_format=55 is the SS58 address
    # format handed to SubstrateInterface — TODO confirm it matches the chain.
    try:
        substrate = SubstrateInterface(url=rpc_url, ss58_format=55, use_remote_preset=False)
    except Exception as e:
        click.echo(f"Failed to connect to RPC endpoint: {str(e)}", err=True)
        # Exit non-zero so shell scripts / CI can detect the failure; a bare
        # `return` would end the command with status 0.
        sys.exit(1)

    click.echo(f"Found {len(acct_data)} accounts in CSV file with balance pairs: {balance_pairs}")
    # Fills in the current_* columns of acct_data in place.
    chain_balances = get_chain_balances(substrate, acct_data, balance_pairs)

    # Default output sits next to the input: wallets.csv -> wallets.checked.csv.
    output_path = output or Path(csv_file).with_suffix('.checked.csv')
    with open(output_path, 'w', newline='') as f:
        # extract_info_from_csv builds every row from the same field list, so
        # the first row's keys are the header for all rows.
        writer = csv.DictWriter(f, fieldnames=list(chain_balances[0]))
        writer.writeheader()
        writer.writerows(chain_balances)

    click.echo(f"Results written to {output_path}")


if __name__ == '__main__':
    main()