aboutsummaryrefslogtreecommitdiffstats
path: root/examples/tutorials/finance_manager/part2/financemodel.py
blob: 0326697ba43c8706f8434b4203c02c231e400fc4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (C) 2024 The Qt Company Ltd.
# SPDX-License-Identifier: LicenseRef-Qt-Commercial OR BSD-3-Clause

from datetime import datetime
from dataclasses import dataclass
from enum import IntEnum
from collections import defaultdict

from PySide6.QtCore import (QAbstractListModel, QEnum, Qt, QModelIndex, Slot,
                            QByteArray)
from PySide6.QtQml import QmlElement
import database

# QML module name and major version under which @QmlElement registers the
# classes in this file, so QML can use them after `import Finance`.
QML_IMPORT_NAME = "Finance"
QML_IMPORT_MAJOR_VERSION = 1


@QmlElement
class FinanceModel(QAbstractListModel):
    """Flat list model exposing ``database.Finance`` rows to QML.

    Each row is held in memory as a :class:`FinanceModel.Finance` record and is
    read from QML through the role names declared in :meth:`roleNames`.
    Entries added via :meth:`append` are persisted through ``self.session``
    and shown at the top of the list.
    """

    @QEnum
    class FinanceRole(IntEnum):
        """Data roles served by the model; ItemNameRole is also Qt.DisplayRole."""
        ItemNameRole = Qt.DisplayRole
        CategoryRole = Qt.UserRole
        CostRole = Qt.UserRole + 1
        DateRole = Qt.UserRole + 2
        MonthRole = Qt.UserRole + 3

    @dataclass
    class Finance:
        """One expense entry as held in memory by the model."""
        item_name: str
        category: str
        cost: float
        date: str  # expected as "dd-mm-YYYY" (the format ``month`` parses)

        @property
        def month(self):
            """Return the full (locale-dependent) month name and year of ``date``.

            Raises ValueError if ``date`` does not match "%d-%m-%Y", and
            TypeError if it is not a string.
            """
            return datetime.strptime(self.date, "%d-%m-%Y").strftime("%B %Y")

    def __init__(self, parent=None) -> None:
        super().__init__(parent)
        # A single session serves both the initial load and later append() calls.
        self.session = database.Session()
        self.m_finances = self.load_finances()

    def load_finances(self):
        """Return every ``database.Finance`` row as a :class:`Finance` record."""
        # NOTE(review): the query has no explicit ordering, while append() puts
        # new entries at the front; confirm which display order is intended.
        return [self.Finance(row.item_name, row.category, row.cost, row.date)
                for row in self.session.query(database.Finance).all()]

    def rowCount(self, parent=QModelIndex()):
        """Return the number of entries.

        This is a flat list, so a valid ``parent`` (a child level) has no rows.
        """
        if parent.isValid():
            return 0
        return len(self.m_finances)

    def data(self, index: QModelIndex, role: int = Qt.DisplayRole):
        """Return the value for ``role`` at ``index``, or None if unavailable.

        None is returned for invalid or out-of-range indexes, for unknown
        roles, and for MonthRole when the stored date cannot be parsed.
        """
        # An invalid QModelIndex reports row() == -1. Without this check,
        # negative indexing would silently return the last entry.
        if not index.isValid():
            return None
        row = index.row()
        if not 0 <= row < len(self.m_finances):
            return None

        finance = self.m_finances[row]
        roles = FinanceModel.FinanceRole
        if role == roles.ItemNameRole:
            return finance.item_name
        if role == roles.CategoryRole:
            return finance.category
        if role == roles.CostRole:
            return finance.cost
        if role == roles.DateRole:
            return finance.date
        if role == roles.MonthRole:
            # append() accepts arbitrary date strings, so don't let one bad
            # entry raise out of every view/binding that reads this role.
            try:
                return finance.month
            except (TypeError, ValueError):
                return None
        return None

    @Slot(result=dict)
    def getCategoryData(self):
        """Return a dict mapping each category to the sum of its costs."""
        category_data = defaultdict(float)
        for finance in self.m_finances:
            category_data[finance.category] += finance.cost
        return dict(category_data)

    def roleNames(self):
        """Return the base role names plus the QML names of the custom roles."""
        roles = super().roleNames()
        roles[FinanceModel.FinanceRole.ItemNameRole] = QByteArray(b"item_name")
        roles[FinanceModel.FinanceRole.CategoryRole] = QByteArray(b"category")
        roles[FinanceModel.FinanceRole.CostRole] = QByteArray(b"cost")
        roles[FinanceModel.FinanceRole.DateRole] = QByteArray(b"date")
        roles[FinanceModel.FinanceRole.MonthRole] = QByteArray(b"month")
        return roles

    @Slot(int, result='QVariantMap')
    def get(self, row: int):
        """Return the entry at ``row`` as a dict, or an empty dict if out of range."""
        # Reject negatives explicitly: Python indexing would otherwise wrap around.
        if not 0 <= row < len(self.m_finances):
            return {}
        finance = self.m_finances[row]
        return {"item_name": finance.item_name, "category": finance.category,
                "cost": finance.cost, "date": finance.date}

    @Slot(str, str, float, str)
    def append(self, item_name: str, category: str, cost: float, date: str):
        """Persist a new entry, then show it at the top of the list.

        The database commit happens before the model is touched. If the commit
        fails, the session is rolled back, the exception is re-raised, and the
        view is left unchanged, so it never shows an unsaved row.
        """
        self.session.add(database.Finance(item_name=item_name, category=category,
                                          cost=cost, date=date))
        try:
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise

        self.beginInsertRows(QModelIndex(), 0, 0)  # new row goes to the front
        self.m_finances.insert(0, self.Finance(item_name, category, cost, date))
        self.endInsertRows()