superset | source code study | the BaseViz class (viz.py)

Execution logic | call chain

  • When a saved chart runs on the explore (chart-building) page, and when its CSV or query results are fetched, the request goes through slice_json on the Superset class in views/core.py.
  • get_payload on BaseViz orchestrates the other methods and returns the final payload: the chart data itself comes from get_data, and the query configuration is built by query_obj.
  • query_obj on BaseViz processes the incoming form_data, i.e. the settings from the configuration panel on the page; chart subclasses override it to validate form_data and raise meaningful errors that tell the user how the current chart must be configured.
  • get_data on BaseViz is also overridden by chart subclasses; it consumes the visualization options and/or further transforms the query result, and its return value is exactly what the front end receives. A minimal subclass sketch follows the get_payload snippet below.

get_payload itself looks like this:
    def get_payload(self, query_obj=None):
        """Returns a payload of metadata and data"""
        self.run_extra_queries()  # currently a no-op; there is no concrete implementation
        payload = self.get_df_payload(query_obj)  # core part: runs the query built from the
                                                  # front-end settings and returns the raw data

        df = payload.get('df')
        if self.status != utils.QueryStatus.FAILED:
            if df is not None and df.empty:
                payload['error'] = 'No data'
            else:
                # get_data() post-processes the query result
                payload['data'] = self.get_data(df)
        if 'df' in payload:
            del payload['df']
        return payload
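
To make the split between query_obj and get_data concrete, here is a minimal, hypothetical subclass sketch. The class name, viz_type and the validation rule are invented for illustration (they are not a real Superset chart type); only the override pattern follows the bullets above.

from flask_babel import lazy_gettext as _

from superset import viz  # assumption: viz.py is importable in a configured Superset environment


class DemoRecordsViz(viz.BaseViz):  # hypothetical chart type
    viz_type = 'demo_records'
    verbose_name = 'Demo Records'
    is_timeseries = False

    def query_obj(self):
        # start from the base query object, then validate the form_data
        d = super(DemoRecordsViz, self).query_obj()
        if not d.get('metrics'):
            # tell the user how this chart must be configured
            raise Exception(_('Please choose at least one metric'))
        return d

    def get_data(self, df):
        # reshape the raw query result into what the front end expects
        return df.to_dict(orient='records')

The full BaseViz class follows, with the author's notes translated inline: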
class BaseViz(object):
    """All visualizations derive this base class"""

    viz_type = None
    verbose_name = 'Base Viz'  # display name of this chart type
    credits = ''  # credit link for the underlying charting library / reference
    is_timeseries = False  # whether this chart is a time series
    default_fillna = 0
    cache_type = 'df'
    enforce_numerical_metrics = True

    def __init__(self, datasource, form_data, force=False):
        # a datasource is required
        if not datasource:
            raise Exception(_('Viz is missing a datasource'))

        self.datasource = datasource
        self.request = request
        # form_data is a dict holding the settings from the left-hand control panel
        self.viz_type = form_data.get('viz_type')
        self.form_data = form_data

        self.query = ''
        # token: permission-related?? (author's open question)
        self.token = self.form_data.get(
            'token', 'token_' + uuid.uuid4().hex[:8])

        self.groupby = self.form_data.get('groupby') or []
        # datetime.timedelta() represents the difference between two points in time; starts as a zero shift
        self.time_shift = timedelta()

        self.status = None
        self.error_message = None
        self.force = force

        # Keeping track of whether some data came from cache
        # this is useful to trigger the <CachedLabel /> when
        # in the cases where visualization have many queries
        # (FilterBox for instance)
        self._some_from_cache = False
        self._any_cache_key = None
        self._any_cached_dttm = None
        self._extra_chart_data = []

        self.process_metrics()  # process the metric settings

    def process_metrics(self):
        # metrics in TableViz is order sensitive, so metric_dict should be
        # OrderedDict (a dict that preserves insertion order)
        self.metric_dict = OrderedDict()
        fd = self.form_data
        for mkey in METRIC_KEYS:
            # METRIC_KEYS is a module-level constant listing the form_data keys that can hold metrics
            val = fd.get(mkey)
            if val:
                # normalize the value to a list
                if not isinstance(val, list):
                    val = [val]
                for o in val:
                    label = self.get_metric_label(o)  # resolve the display label via get_metric_label()
                    if isinstance(o, dict):
                        # write the resolved label back onto the adhoc-metric dict: 'label': label
                        o['label'] = label
                    self.metric_dict[label] = o

        # Cast to list needed to return serializable object in py3
        self.all_metrics = list(self.metric_dict.values())
        self.metric_labels = list(self.metric_dict.keys())
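
As a standalone illustration (not part of viz.py), the loop above flattens whatever metric settings exist in form_data into an ordered label -> metric mapping. The snippet below reproduces the idea with a reduced stand-in for METRIC_KEYS and with get_metric_label simplified to just reading the 'label' key:

from collections import OrderedDict

METRIC_KEYS = ['metric', 'metrics']  # assumption: a reduced stand-in for the real module constant
form_data = {'metrics': [{'label': 'SUM(sum_boys)'}, 'count']}

metric_dict = OrderedDict()
for mkey in METRIC_KEYS:
    val = form_data.get(mkey)
    if val:
        if not isinstance(val, list):
            val = [val]
        for o in val:
            label = o['label'] if isinstance(o, dict) else o  # simplified get_metric_label()
            metric_dict[label] = o

print(list(metric_dict))           # ['SUM(sum_boys)', 'count']
print(list(metric_dict.values()))  # [{'label': 'SUM(sum_boys)'}, 'count']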

    # resolve the display label of a metric (plain string or adhoc-metric dict)
    def get_metric_label(self, metric):
        if isinstance(metric, str):
            return metric

        if isinstance(metric, dict):
            # label: "SUM(sum_boys)"
            metric = metric.get('label')
        # only for 'table' (SQLAlchemy) datasources, as opposed to Druid
        if self.datasource.type == 'table':
            # fixme (author's note): consider the case where the bound database is removed
            db_engine_spec = self.datasource.database.db_engine_spec
            metric = db_engine_spec.mutate_expression_label(metric)
        return metric

    @staticmethod
    def handle_js_int_overflow(data):
        # handle oversized integers: values larger than JavaScript can represent are converted to strings
        for d in data.get('records', dict()):
            for k, v in list(d.items()):
                if isinstance(v, int):
                    # if an int is too big for Java Script to handle
                    # convert it to a string
                    if abs(v) > JS_MAX_INTEGER:
                        # JS_MAX_INTEGER = 9007199254740991, the largest int JavaScript can handle (2**53 - 1)
                        d[k] = str(v)
        return data
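
A quick standalone check of this behavior (assuming a configured Superset environment where superset.viz is importable):

from superset import viz

data = {'records': [{'id': 2 ** 60, 'name': 'a'}, {'id': 7, 'name': 'b'}]}
data = viz.BaseViz.handle_js_int_overflow(data)
print(data['records'][0]['id'])  # '1152921504606846976' -- stringified, since 2**60 > 2**53 - 1
print(data['records'][1]['id'])  # 7 -- small ints are left untouched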

    def run_extra_queries(self):  # hook for running extra queries; a no-op in the base class
        pass

    def handle_nulls(self, df):
        # fill in missing values
        fillna = self.get_fillna_for_columns(df.columns)
        return df.fillna(fillna)

    def get_fillna_for_col(self, col):
        """Returns the value to use as filler for a specific Column.type"""
        if col:
            if col.is_string:
                return 'NULL'
        return self.default_fillna  # 0

    def get_fillna_for_columns(self, columns=None):
        """Returns a dict or scalar that can be passed to DataFrame.fillna"""
        if columns is None:
            return self.default_fillna  # self.default_fillna = 0
        # dict comprehension: build the column_name -> column object mapping directly
        columns_dict = {col.column_name: col for col in self.datasource.columns}
        fillna = {
            c: self.get_fillna_for_col(columns_dict.get(c))  # 'NULL' for string columns, 0 otherwise
            for c in columns
        }
        return fillna
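
For instance (standalone sketch, column names made up), the dict built by get_fillna_for_columns() is handed straight to DataFrame.fillna by handle_nulls():

import pandas as pd

# string columns are filled with the literal 'NULL', everything else with default_fillna (0)
fillna = {'gender': 'NULL', 'num_boys': 0}

df = pd.DataFrame({'gender': ['girl', None], 'num_boys': [10, None]})
print(df.fillna(fillna))
#   gender  num_boys
# 0   girl      10.0
# 1   NULL       0.0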

    def get_samples(self):
        query_obj = self.query_obj()
        query_obj.update({
            'groupby': [],
            'metrics': [],
            'row_limit': 1000,
            'columns': [o.column_name for o in self.datasource.columns],
            # select all raw columns so that sample rows can be returned
        })
        # run the query and get the data
        df = self.get_df(query_obj)
        return df.to_dict(orient='records')

    def get_df(self, query_obj=None):
        """Returns a pandas dataframe based on the query object"""
        if not query_obj:
            # print(self.query_obj())
            query_obj = self.query_obj()
        if not query_obj:
            return None

        # reset the error message
        self.error_msg = ''

        timestamp_format = None
        if self.datasource.type == 'table':
            dttm_col = self.datasource.get_col(query_obj['granularity'])  # granularity: the time column / grain
            if dttm_col:
                timestamp_format = dttm_col.python_date_format  # the date format configured on the column ('epoch_s', 'epoch_ms' or a strftime pattern)

        # The datasource here can be a different backend but the interface is common
        self.results = self.datasource.query(query_obj)
        self.query = self.results.query
        self.status = self.results.status
        self.error_message = self.results.error_message

        df = self.results.df
        # Transform the timestamp we received from database to pandas supported
        # datetime format. If no python_date_format is specified, the pattern will
        # be considered as the default ISO date format
        # If the datetime format is unix, the parse will use the corresponding
        # parsing logic.
        if df is not None and not df.empty:
            # DTTM_ALIAS = '__timestamp' is defined at module level
            if DTTM_ALIAS in df.columns:
                if timestamp_format in ('epoch_s', 'epoch_ms'):
                    # Column has already been formatted as a timestamp.
                    # convert to pandas Timestamps
                    df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(pd.Timestamp)
                else:
                    # parse as pandas datetimes
                    df[DTTM_ALIAS] = pd.to_datetime(
                        df[DTTM_ALIAS], utc=False, format=timestamp_format)
                # offset is defined on the datasource model, in hours
                if self.datasource.offset:
                    df[DTTM_ALIAS] += timedelta(hours=self.datasource.offset)
                df[DTTM_ALIAS] += self.time_shift

            if self.enforce_numerical_metrics:
                self.df_metrics_to_num(df)

            # replace +inf / -inf with NaN (note: the result of replace() is not assigned here,
            # so as written this line has no effect)
            df.replace([np.inf, -np.inf], np.nan)
            df = self.handle_nulls(df)
        return df
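
A standalone sketch of the __timestamp handling above, assuming an ISO-formatted column, a datasource offset of 8 hours and the default empty time_shift:

from datetime import timedelta

import pandas as pd

DTTM_ALIAS = '__timestamp'
df = pd.DataFrame({DTTM_ALIAS: ['2020-04-01 00:00:00', '2020-04-02 00:00:00'], 'cnt': [1, 2]})

# no python_date_format configured -> default ISO parsing (format=None)
df[DTTM_ALIAS] = pd.to_datetime(df[DTTM_ALIAS], utc=False, format=None)

df[DTTM_ALIAS] += timedelta(hours=8)  # datasource.offset, in hours
df[DTTM_ALIAS] += timedelta()         # self.time_shift defaults to an empty delta
print(df[DTTM_ALIAS].iloc[0])         # 2020-04-01 08:00:00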

    def df_metrics_to_num(self, df):
        """Converting metrics to numeric when pandas.read_sql cannot"""
        metrics = self.metric_labels
        for col, dtype in df.dtypes.items():
            # column name and its dtype
            if dtype.type == np.object_ and col in metrics:
                # coerce the column to a numeric dtype with pd.to_numeric()
                df[col] = pd.to_numeric(df[col], errors='coerce')
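
A one-line standalone illustration of the coercion: strings become floats, and values that cannot be parsed become NaN:

import pandas as pd

s = pd.Series(['1', '2', 'not-a-number'], dtype=object)
print(pd.to_numeric(s, errors='coerce').tolist())  # [1.0, 2.0, nan]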

    def process_query_filters(self):
        # normalize the filter settings
        utils.convert_legacy_filters_into_adhoc(self.form_data)  # filters ==> adhoc_filters; the legacy 'filters' key is removed
        merge_extra_filters(self.form_data)
        utils.split_adhoc_filters_into_base_filters(self.form_data)
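
Roughly, convert_legacy_filters_into_adhoc rewrites the old flat filters entries as adhoc_filters; the shapes below are from memory and the exact keys may differ between Superset versions:

# before: legacy filter entries in form_data
form_data = {'filters': [{'col': 'gender', 'op': '==', 'val': 'girl'}]}

# after (approximate shape): the legacy 'filters' key is dropped in favour of
form_data = {
    'adhoc_filters': [{
        'clause': 'WHERE',
        'expressionType': 'SIMPLE',
        'subject': 'gender',
        'operator': '==',
        'comparator': 'girl',
    }],
}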

    def query_obj(self):
        """Building a query object"""
        form_data = self.form_data
        # normalize adhoc_filters first
        self.process_query_filters()
        gb = form_data.get('groupby') or []
        metrics = self.all_metrics or []
        columns = form_data.get('columns') or []
        # print('query_oj', gb)
        groupby = []
        for o in gb + columns:
            if o not in groupby:
                groupby.append(o)

        is_timeseries = self.is_timeseries
        if DTTM_ALIAS in groupby:
            groupby.remove(DTTM_ALIAS)
            is_timeseries = True

        granularity = (
                form_data.get('granularity') or
                form_data.get('granularity_sqla')
        )
        limit = int(form_data.get('limit') or 0)
        timeseries_limit_metric = form_data.get('timeseries_limit_metric')
        row_limit = int(form_data.get('row_limit') or config.get('ROW_LIMIT'))

        # default order direction
        order_desc = form_data.get('order_desc', True)

        since, until = utils.get_since_until(form_data.get('time_range'),
                                             form_data.get('since'),
                                             form_data.get('until'))
        # get_since_until() resolves the time_range / since / until settings into start and end datetimes
        time_shift = form_data.get('time_shift', '')
        self.time_shift = utils.parse_human_timedelta(time_shift)
        from_dttm = None if since is None else (since - self.time_shift)
        to_dttm = None if until is None else (until - self.time_shift)
        if from_dttm and to_dttm and from_dttm > to_dttm:
            raise Exception(_('From date cannot be larger than to date'))

        self.from_dttm = from_dttm
        self.to_dttm = to_dttm

        # extras are used to query elements specific to a datasource type
        # for instance the extra where clause that applies only to Tables
        extras = {
            'where': form_data.get('where', ''),
            'having': form_data.get('having', ''),
            'having_druid': form_data.get('having_filters', []),
            'time_grain_sqla': form_data.get('time_grain_sqla', ''),
            'druid_time_origin': form_data.get('druid_time_origin', ''),
        }

        d = {
            'granularity': granularity,
            'from_dttm': from_dttm,
            'to_dttm': to_dttm,
            'is_timeseries': is_timeseries,
            'groupby': groupby,
            'metrics': metrics,
            'row_limit': row_limit,
            'filter': self.form_data.get('filters', []),
            'timeseries_limit': limit,
            'extras': extras,
            'timeseries_limit_metric': timeseries_limit_metric,
            'order_desc': order_desc,
            'prequeries': [],
            'is_prequery': False,
        }
        return d
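
To see what query_obj() hands to datasource.query(), here is a hypothetical form_data for a table chart and the rough shape of the resulting dict (values elided where they depend on runtime config):

form_data = {
    'viz_type': 'table',
    'granularity_sqla': 'ds',
    'groupby': ['gender'],
    'metrics': ['count'],
    'row_limit': 5000,
    'time_range': 'Last week',
}

# query_obj() would then return roughly:
# {
#     'granularity': 'ds',
#     'groupby': ['gender'],
#     'metrics': ['count'],
#     'is_timeseries': False,
#     'row_limit': 5000,
#     'from_dttm': ..., 'to_dttm': ...,   # resolved from 'Last week'
#     'filter': [], 'extras': {...},
#     'timeseries_limit': 0, 'order_desc': True,
#     ...
# }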

    @property
    def cache_timeout(self):
        if self.form_data.get('cache_timeout') is not None:
            return int(self.form_data.get('cache_timeout'))
        if self.datasource.cache_timeout is not None:
            return self.datasource.cache_timeout
        if (
                hasattr(self.datasource, 'database') and
                self.datasource.database.cache_timeout) is not None:
            return self.datasource.database.cache_timeout
        return config.get('CACHE_DEFAULT_TIMEOUT')

    def get_json(self):
        return json.dumps(
            self.get_payload(),
            default=utils.json_int_dttm_ser, ignore_nan=True)

    def cache_key(self, query_obj, **extra):
        """
        The cache key is made out of the key/values in `query_obj`, plus any
        other key/values in `extra`.

        We remove datetime bounds that are hard values, and replace them with
        the user-provided inputs to bounds, which may be time-relative (as in
        "5 days ago" or "now").

        The `extra` arguments are currently used by time shift queries, since
        different time shifts will differ only in the `from_dttm` and `to_dttm`
        values which are stripped.
        """
        cache_dict = copy.copy(query_obj)
        cache_dict.update(extra)

        for k in ['from_dttm', 'to_dttm']:
            del cache_dict[k]

        cache_dict['time_range'] = self.form_data.get('time_range')
        cache_dict['datasource'] = self.datasource.uid
        json_data = self.json_dumps(cache_dict, sort_keys=True)
        return hashlib.md5(json_data.encode('utf-8')).hexdigest()
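
The key construction can be reproduced standalone (plain json.dumps stands in for self.json_dumps, and the datasource uid is made up):

import copy
import hashlib
import json

query_obj = {
    'granularity': 'ds',
    'metrics': ['count'],
    'from_dttm': '2020-04-01T00:00:00',
    'to_dttm': '2020-04-08T00:00:00',
}

cache_dict = copy.copy(query_obj)
for k in ['from_dttm', 'to_dttm']:
    del cache_dict[k]                    # the hard datetime bounds are stripped...
cache_dict['time_range'] = 'Last week'   # ...and replaced by the relative expression
cache_dict['datasource'] = '3__table'    # made-up datasource uid

json_data = json.dumps(cache_dict, sort_keys=True)
print(hashlib.md5(json_data.encode('utf-8')).hexdigest())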

    def get_payload(self, query_obj=None):
        """Returns a payload of metadata and data"""
        self.run_extra_queries()  # currently a no-op; there is no concrete implementation
        payload = self.get_df_payload(query_obj)  # core part: fetch the data using the query conditions

        df = payload.get('df')
        if self.status != utils.QueryStatus.FAILED:
            if df is not None and df.empty:
                payload['error'] = 'No data'
            else:
                # get_data() post-processes the query result
                payload['data'] = self.get_data(df)
        if 'df' in payload:
            del payload['df']
        return payload

    def get_df_payload(self, query_obj=None, **kwargs):
        """Handles caching around the df payload retrieval"""
        if not query_obj:
            query_obj = self.query_obj()
        cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None
        logging.info('Cache key: {}'.format(cache_key))
        is_loaded = False
        stacktrace = None
        df = None
        # the current UTC time, used as the cache timestamp
        cached_dttm = datetime.utcnow().isoformat().split('.')[0]
        if cache_key and cache and not self.force:
            cache_value = cache.get(cache_key)
            if cache_value:
                stats_logger.incr('loaded_from_cache')
                try:
                    cache_value = pkl.loads(cache_value)
                    df = cache_value['df']
                    self.query = cache_value['query']
                    self._any_cached_dttm = cache_value['dttm']
                    self._any_cache_key = cache_key
                    self.status = utils.QueryStatus.SUCCESS
                    # mark the cache read as successful
                    is_loaded = True
                except Exception as e:
                    logging.exception(e)
                    logging.error('Error reading cache: ' +
                                  utils.error_msg_from_exception(e))
                logging.info('Serving from cache')

        if query_obj and not is_loaded:
            # only runs when is_loaded is False, i.e. the cache read above did not succeed
            try:
                df = self.get_df(query_obj)
                if self.status != utils.QueryStatus.FAILED:
                    stats_logger.incr('loaded_from_source')
                    is_loaded = True
            except Exception as e:
                logging.exception(e)
                if not self.error_message:
                    self.error_message = '{}'.format(e)
                self.status = utils.QueryStatus.FAILED
                stacktrace = traceback.format_exc()

            if (
                    is_loaded and
                    cache_key and
                    cache and
                    self.status != utils.QueryStatus.FAILED):
                try:
                    cache_value = dict(
                        dttm=cached_dttm,
                        df=df if df is not None else None,
                        query=self.query,
                    )
                    cache_value = pkl.dumps(
                        cache_value, protocol=pkl.HIGHEST_PROTOCOL)

                    logging.info('Caching {} chars at key {}'.format(
                        len(cache_value), cache_key))

                    stats_logger.incr('set_cache_key')
                    cache.set(
                        cache_key,
                        cache_value,
                        timeout=self.cache_timeout)
                except Exception as e:
                    # cache.set call can fail if the backend is down or if
                    # the key is too large or whatever other reasons
                    logging.warning('Could not cache key {}'.format(cache_key))
                    logging.exception(e)
                    cache.delete(cache_key)
        return {
            'cache_key': self._any_cache_key,
            'cached_dttm': self._any_cached_dttm,
            'cache_timeout': self.cache_timeout,
            'df': df,
            'error': self.error_message,
            'form_data': self.form_data,
            'is_cached': self._any_cache_key is not None,
            'query': self.query,
            'status': self.status,
            'stacktrace': stacktrace,
            'rowcount': len(df.index) if df is not None else 0,
        }

    # serialize an object to a JSON string
    def json_dumps(self, obj, sort_keys=False):
        return json.dumps(
            obj,
            default=utils.json_int_dttm_ser,
            ignore_nan=True,
            sort_keys=sort_keys,
        )

    def payload_json_and_has_error(self, payload):
        has_error = payload.get('status') == utils.QueryStatus.FAILED or \
            payload.get('error') is not None
        return self.json_dumps(payload), has_error

    @property
    def data(self):
        """This is the data object serialized to the js layer"""
        # the chart's configuration/metadata serialized to the JS layer
        content = {
            'form_data': self.form_data,
            'token': self.token,
            'viz_name': self.viz_type,
            'filter_select_enabled': self.datasource.filter_select_enabled,
        }
        return content

    def get_csv(self):
        df = self.get_df()
        # include the DataFrame index in the CSV only when it is not a plain pd.RangeIndex
        include_index = not isinstance(df.index, pd.RangeIndex)
        # export as CSV, using the CSV_EXPORT options from the config
        return df.to_csv(index=include_index, **config.get('CSV_EXPORT'))

    def get_data(self, df):
        # DataFrame => list of record dicts
        return df.to_dict(orient='records')

    @property
    def json_data(self):
        return json.dumps(self.data)
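
Finally, a rough sketch of how the views/core.py side drives all of this, given a datasource, a form_data dict and a force flag (simplified; check your checkout for the exact helper names and signatures):

from superset import viz

# viz.viz_types maps viz_type strings to BaseViz subclasses (built at the bottom of viz.py)
viz_cls = viz.viz_types[form_data.get('viz_type')]
viz_obj = viz_cls(datasource, form_data=form_data, force=force)

payload = viz_obj.get_payload()   # the slice_json / explore_json path: chart JSON
csv_text = viz_obj.get_csv()      # the CSV export path
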
Original post: https://www.cnblogs.com/bennyjane/p/12702016.html